diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
new file mode 100644
index 000000000..20c878bbb
--- /dev/null
+++ b/.git-blame-ignore-revs
@@ -0,0 +1,2 @@
+# Scala Steward: Reformat with scalafmt 3.8.6
+67ab2b447c399dc16378de4d527930cc5b7f1c7a
diff --git a/.scalafmt.conf b/.scalafmt.conf
index 038f3d925..771bfd31a 100644
--- a/.scalafmt.conf
+++ b/.scalafmt.conf
@@ -1,4 +1,4 @@
-version = 3.8.1
+version = 3.8.6
runner.dialect = scala213
newlines.beforeMultilineDef = keep
diff --git a/build.sbt b/build.sbt
index 4f8064cb3..77370f693 100644
--- a/build.sbt
+++ b/build.sbt
@@ -244,7 +244,10 @@ lazy val datasetSettings =
)
},
coverageExcludedPackages := "org.apache.spark.sql.reflection",
- libraryDependencies += "com.globalmentor" % "hadoop-bare-naked-local-fs" % nakedFSVersion % Test exclude ("org.apache.hadoop", "hadoop-commons")
+ libraryDependencies += "com.globalmentor" % "hadoop-bare-naked-local-fs" % nakedFSVersion % Test exclude (
+ "org.apache.hadoop",
+ "hadoop-commons"
+ )
)
lazy val refinedSettings =
diff --git a/cats/src/main/scala/frameless/cats/FramelessSyntax.scala b/cats/src/main/scala/frameless/cats/FramelessSyntax.scala
index 663ae5958..5bc5d63e7 100644
--- a/cats/src/main/scala/frameless/cats/FramelessSyntax.scala
+++ b/cats/src/main/scala/frameless/cats/FramelessSyntax.scala
@@ -7,18 +7,25 @@ import _root_.cats.mtl.Ask
import org.apache.spark.sql.SparkSession
trait FramelessSyntax extends frameless.FramelessSyntax {
- implicit class SparkJobOps[F[_], A](fa: F[A])(implicit S: Sync[F], A: Ask[F, SparkSession]) {
+
+ implicit class SparkJobOps[F[_], A](
+ fa: F[A]
+ )(implicit
+ S: Sync[F],
+ A: Ask[F, SparkSession]) {
import S._, A._
def withLocalProperty(key: String, value: String): F[A] =
for {
session <- ask
- _ <- delay(session.sparkContext.setLocalProperty(key, value))
- a <- fa
+ _ <- delay(session.sparkContext.setLocalProperty(key, value))
+ a <- fa
} yield a
- def withGroupId(groupId: String): F[A] = withLocalProperty("spark.jobGroup.id", groupId)
+ def withGroupId(groupId: String): F[A] =
+ withLocalProperty("spark.jobGroup.id", groupId)
- def withDescription(description: String): F[A] = withLocalProperty("spark.job.description", description)
+ def withDescription(description: String): F[A] =
+ withLocalProperty("spark.job.description", description)
}
}
diff --git a/cats/src/main/scala/frameless/cats/SparkDelayInstances.scala b/cats/src/main/scala/frameless/cats/SparkDelayInstances.scala
index 524c44117..ecd893925 100644
--- a/cats/src/main/scala/frameless/cats/SparkDelayInstances.scala
+++ b/cats/src/main/scala/frameless/cats/SparkDelayInstances.scala
@@ -5,7 +5,15 @@ import _root_.cats.effect.Sync
import org.apache.spark.sql.SparkSession
trait SparkDelayInstances {
- implicit def framelessCatsSparkDelayForSync[F[_]](implicit S: Sync[F]): SparkDelay[F] = new SparkDelay[F] {
- def delay[A](a: => A)(implicit spark: SparkSession): F[A] = S.delay(a)
+
+ implicit def framelessCatsSparkDelayForSync[F[_]](
+ implicit
+ S: Sync[F]
+ ): SparkDelay[F] = new SparkDelay[F] {
+ def delay[A](
+ a: => A
+ )(implicit
+ spark: SparkSession
+ ): F[A] = S.delay(a)
}
}
diff --git a/cats/src/main/scala/frameless/cats/SparkTask.scala b/cats/src/main/scala/frameless/cats/SparkTask.scala
index 3a6e6330b..166835a6a 100644
--- a/cats/src/main/scala/frameless/cats/SparkTask.scala
+++ b/cats/src/main/scala/frameless/cats/SparkTask.scala
@@ -6,6 +6,7 @@ import _root_.cats.data.Kleisli
import org.apache.spark.SparkContext
object SparkTask {
+
def apply[A](f: SparkContext => A): SparkTask[A] =
Kleisli[Id, SparkContext, A](f)
diff --git a/cats/src/main/scala/frameless/cats/implicits.scala b/cats/src/main/scala/frameless/cats/implicits.scala
index 1fa869a7f..fdfadf80d 100644
--- a/cats/src/main/scala/frameless/cats/implicits.scala
+++ b/cats/src/main/scala/frameless/cats/implicits.scala
@@ -2,7 +2,7 @@ package frameless
package cats
import _root_.cats._
-import _root_.cats.kernel.{CommutativeMonoid, CommutativeSemigroup}
+import _root_.cats.kernel.{ CommutativeMonoid, CommutativeSemigroup }
import _root_.cats.syntax.all._
import alleycats.Empty
@@ -10,42 +10,78 @@ import scala.reflect.ClassTag
import org.apache.spark.rdd.RDD
object implicits extends FramelessSyntax with SparkDelayInstances {
+
implicit class rddOps[A: ClassTag](lhs: RDD[A]) {
- def csum(implicit m: CommutativeMonoid[A]): A =
+
+ def csum(
+ implicit
+ m: CommutativeMonoid[A]
+ ): A =
lhs.fold(m.empty)(_ |+| _)
- def csumOption(implicit m: CommutativeSemigroup[A]): Option[A] =
+
+ def csumOption(
+ implicit
+ m: CommutativeSemigroup[A]
+ ): Option[A] =
lhs.aggregate[Option[A]](None)(
(acc, a) => Some(acc.fold(a)(_ |+| a)),
(l, r) => l.fold(r)(x => r.map(_ |+| x) orElse Some(x))
)
- def cmin(implicit o: Order[A], e: Empty[A]): A = {
+ def cmin(implicit
+ o: Order[A],
+ e: Empty[A]
+ ): A = {
if (lhs.isEmpty()) e.empty
else lhs.reduce(_ min _)
}
- def cminOption(implicit o: Order[A]): Option[A] =
+
+ def cminOption(
+ implicit
+ o: Order[A]
+ ): Option[A] =
csumOption(new CommutativeSemigroup[A] {
def combine(l: A, r: A) = l min r
})
- def cmax(implicit o: Order[A], e: Empty[A]): A = {
+ def cmax(implicit
+ o: Order[A],
+ e: Empty[A]
+ ): A = {
if (lhs.isEmpty()) e.empty
else lhs.reduce(_ max _)
}
- def cmaxOption(implicit o: Order[A]): Option[A] =
+
+ def cmaxOption(
+ implicit
+ o: Order[A]
+ ): Option[A] =
csumOption(new CommutativeSemigroup[A] {
def combine(l: A, r: A) = l max r
})
}
implicit class pairRddOps[K: ClassTag, V: ClassTag](lhs: RDD[(K, V)]) {
- def csumByKey(implicit m: CommutativeSemigroup[V]): RDD[(K, V)] = lhs.reduceByKey(_ |+| _)
- def cminByKey(implicit o: Order[V]): RDD[(K, V)] = lhs.reduceByKey(_ min _)
- def cmaxByKey(implicit o: Order[V]): RDD[(K, V)] = lhs.reduceByKey(_ max _)
+
+ def csumByKey(
+ implicit
+ m: CommutativeSemigroup[V]
+ ): RDD[(K, V)] = lhs.reduceByKey(_ |+| _)
+
+ def cminByKey(
+ implicit
+ o: Order[V]
+ ): RDD[(K, V)] = lhs.reduceByKey(_ min _)
+
+ def cmaxByKey(
+ implicit
+ o: Order[V]
+ ): RDD[(K, V)] = lhs.reduceByKey(_ max _)
}
}
object union {
+
implicit def unionSemigroup[A]: Semigroup[RDD[A]] =
new Semigroup[RDD[A]] {
def combine(lhs: RDD[A], rhs: RDD[A]): RDD[A] = lhs union rhs
@@ -53,7 +89,11 @@ object union {
}
object inner {
- implicit def pairwiseInnerSemigroup[K: ClassTag, V: ClassTag: Semigroup]: Semigroup[RDD[(K, V)]] =
+
+ implicit def pairwiseInnerSemigroup[
+ K: ClassTag,
+ V: ClassTag: Semigroup
+ ]: Semigroup[RDD[(K, V)]] =
new Semigroup[RDD[(K, V)]] {
def combine(lhs: RDD[(K, V)], rhs: RDD[(K, V)]): RDD[(K, V)] =
lhs.join(rhs).mapValues { case (x, y) => x |+| y }
@@ -61,14 +101,18 @@ object inner {
}
object outer {
- implicit def pairwiseOuterSemigroup[K: ClassTag, V: ClassTag](implicit m: Monoid[V]): Semigroup[RDD[(K, V)]] =
+
+ implicit def pairwiseOuterSemigroup[K: ClassTag, V: ClassTag](
+ implicit
+ m: Monoid[V]
+ ): Semigroup[RDD[(K, V)]] =
new Semigroup[RDD[(K, V)]] {
def combine(lhs: RDD[(K, V)], rhs: RDD[(K, V)]): RDD[(K, V)] =
lhs.fullOuterJoin(rhs).mapValues {
case (Some(x), Some(y)) => x |+| y
- case (None, Some(y)) => y
- case (Some(x), None) => x
- case (None, None) => m.empty
+ case (None, Some(y)) => y
+ case (Some(x), None) => x
+ case (None, None) => m.empty
}
}
}
diff --git a/cats/src/test/scala/frameless/cats/FramelessSyntaxTests.scala b/cats/src/test/scala/frameless/cats/FramelessSyntaxTests.scala
index 95ffbed26..a5246ad37 100644
--- a/cats/src/test/scala/frameless/cats/FramelessSyntaxTests.scala
+++ b/cats/src/test/scala/frameless/cats/FramelessSyntaxTests.scala
@@ -6,16 +6,18 @@ import _root_.cats.effect.IO
import _root_.cats.effect.unsafe.implicits.global
import org.apache.spark.sql.SparkSession
import org.scalatest.matchers.should.Matchers
-import org.scalacheck.{Test => PTest}
+import org.scalacheck.{ Test => PTest }
import org.scalacheck.Prop, Prop._
import org.scalacheck.effect.PropF, PropF._
class FramelessSyntaxTests extends TypedDatasetSuite with Matchers {
override val sparkDelay = null
- def prop[A, B](data: Vector[X2[A, B]])(
- implicit ev: TypedEncoder[X2[A, B]]
- ): Prop = {
+ def prop[A, B](
+ data: Vector[X2[A, B]]
+ )(implicit
+ ev: TypedEncoder[X2[A, B]]
+ ): Prop = {
import implicits._
val dataset = TypedDataset.create(data).dataset
@@ -24,7 +26,13 @@ class FramelessSyntaxTests extends TypedDatasetSuite with Matchers {
val typedDataset = dataset.typed
val typedDatasetFromDataFrame = dataframe.unsafeTyped[X2[A, B]]
- typedDataset.collect[IO]().unsafeRunSync().toVector ?= typedDatasetFromDataFrame.collect[IO]().unsafeRunSync().toVector
+ typedDataset
+ .collect[IO]()
+ .unsafeRunSync()
+ .toVector ?= typedDatasetFromDataFrame
+ .collect[IO]()
+ .unsafeRunSync()
+ .toVector
}
test("dataset typed - toTyped") {
@@ -37,8 +45,7 @@ class FramelessSyntaxTests extends TypedDatasetSuite with Matchers {
forAllF { (k: String, v: String) =>
val scopedKey = "frameless.tests." + k
- 1
- .pure[ReaderT[IO, SparkSession, *]]
+ 1.pure[ReaderT[IO, SparkSession, *]]
.withLocalProperty(scopedKey, v)
.withGroupId(v)
.withDescription(v)
@@ -47,7 +54,8 @@ class FramelessSyntaxTests extends TypedDatasetSuite with Matchers {
sc.getLocalProperty(scopedKey) shouldBe v
sc.getLocalProperty("spark.jobGroup.id") shouldBe v
sc.getLocalProperty("spark.job.description") shouldBe v
- }.void
+ }
+ .void
}.check().unsafeRunSync().status shouldBe PTest.Passed
}
}
diff --git a/cats/src/test/scala/frameless/cats/test.scala b/cats/src/test/scala/frameless/cats/test.scala
index d75bc3bfd..d9314238f 100644
--- a/cats/src/test/scala/frameless/cats/test.scala
+++ b/cats/src/test/scala/frameless/cats/test.scala
@@ -7,7 +7,7 @@ import _root_.cats.syntax.all._
import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.rdd.RDD
-import org.apache.spark.{SparkConf, SparkContext => SC}
+import org.apache.spark.{ SparkConf, SparkContext => SC }
import org.scalatest.compatible.Assertion
import org.scalactic.anyvals.PosInt
@@ -21,7 +21,11 @@ import org.scalatest.matchers.should.Matchers
import org.scalatest.propspec.AnyPropSpec
trait SparkTests {
- val appID: String = new java.util.Date().toString + math.floor(math.random() * 10E4).toLong.toString
+
+ val appID: String = new java.util.Date().toString + math
+ .floor(math.random() * 10e4)
+ .toLong
+ .toString
val conf: SparkConf = new SparkConf()
.setMaster("local[*]")
@@ -29,16 +33,27 @@ trait SparkTests {
.set("spark.ui.enabled", "false")
.set("spark.app.id", appID)
- implicit def session: SparkSession = SparkSession.builder().config(conf).getOrCreate()
+ implicit def session: SparkSession =
+ SparkSession.builder().config(conf).getOrCreate()
implicit def sc: SparkContext = session.sparkContext
- implicit class seqToRdd[A: ClassTag](seq: Seq[A])(implicit sc: SC) {
+ implicit class seqToRdd[A: ClassTag](
+ seq: Seq[A]
+ )(implicit
+ sc: SC) {
def toRdd: RDD[A] = sc.makeRDD(seq)
}
}
object Tests {
- def innerPairwise(mx: Map[String, Int], my: Map[String, Int], check: (Any, Any) => Assertion)(implicit sc: SC): Assertion = {
+
+ def innerPairwise(
+ mx: Map[String, Int],
+ my: Map[String, Int],
+ check: (Any, Any) => Assertion
+ )(implicit
+ sc: SC
+ ): Assertion = {
import frameless.cats.implicits._
import frameless.cats.inner._
val xs = sc.parallelize(mx.toSeq)
@@ -63,18 +78,27 @@ object Tests {
}
}
-class Test extends AnyPropSpec with Matchers with ScalaCheckPropertyChecks with SparkTests {
+class Test
+ extends AnyPropSpec
+ with Matchers
+ with ScalaCheckPropertyChecks
+ with SparkTests {
+
implicit override val generatorDrivenConfig =
PropertyCheckConfiguration(minSize = PosInt(10))
property("spark is working") {
- sc.parallelize(Seq(1, 2, 3)).collect() shouldBe Array(1,2,3)
+ sc.parallelize(Seq(1, 2, 3)).collect() shouldBe Array(1, 2, 3)
}
property("inner pairwise monoid") {
// Make sure we have non-empty map
- forAll { (xh: (String, Int), mx: Map[String, Int], yh: (String, Int), my: Map[String, Int]) =>
- Tests.innerPairwise(mx + xh, my + yh, _ shouldBe _)
+ forAll {
+ (xh: (String, Int),
+ mx: Map[String, Int],
+ yh: (String, Int),
+ my: Map[String, Int]
+ ) => Tests.innerPairwise(mx + xh, my + yh, _ shouldBe _)
}
}
@@ -110,7 +134,8 @@ class Test extends AnyPropSpec with Matchers with ScalaCheckPropertyChecks with
property("rdd tuple commutative semigroup example") {
import frameless.cats.implicits._
forAll { seq: List[(Int, Int)] =>
- val expectedSum = if (seq.isEmpty) None else Some(Foldable[List].fold(seq))
+ val expectedSum =
+ if (seq.isEmpty) None else Some(Foldable[List].fold(seq))
val rdd = seq.toRdd
rdd.csum shouldBe expectedSum.getOrElse(0 -> 0)
@@ -120,10 +145,22 @@ class Test extends AnyPropSpec with Matchers with ScalaCheckPropertyChecks with
property("pair rdd numeric commutative semigroup example") {
import frameless.cats.implicits._
- val seq = Seq( ("a",2), ("b",3), ("d",6), ("b",2), ("d",1) )
+ val seq = Seq(("a", 2), ("b", 3), ("d", 6), ("b", 2), ("d", 1))
val rdd = seq.toRdd
- rdd.cminByKey.collect().toSeq should contain theSameElementsAs Seq( ("a",2), ("b",2), ("d",1) )
- rdd.cmaxByKey.collect().toSeq should contain theSameElementsAs Seq( ("a",2), ("b",3), ("d",6) )
- rdd.csumByKey.collect().toSeq should contain theSameElementsAs Seq( ("a",2), ("b",5), ("d",7) )
+ rdd.cminByKey.collect().toSeq should contain theSameElementsAs Seq(
+ ("a", 2),
+ ("b", 2),
+ ("d", 1)
+ )
+ rdd.cmaxByKey.collect().toSeq should contain theSameElementsAs Seq(
+ ("a", 2),
+ ("b", 3),
+ ("d", 6)
+ )
+ rdd.csumByKey.collect().toSeq should contain theSameElementsAs Seq(
+ ("a", 2),
+ ("b", 5),
+ ("d", 7)
+ )
}
}
diff --git a/core/src/main/scala/frameless/CatalystAverageable.scala b/core/src/main/scala/frameless/CatalystAverageable.scala
index 401ed65fc..60a9ee347 100644
--- a/core/src/main/scala/frameless/CatalystAverageable.scala
+++ b/core/src/main/scala/frameless/CatalystAverageable.scala
@@ -3,24 +3,35 @@ package frameless
import scala.annotation.implicitNotFound
/**
- * When averaging Spark doesn't change these types:
- * - BigDecimal -> BigDecimal
- * - Double -> Double
- * But it changes these types :
- * - Int -> Double
- * - Short -> Double
- * - Long -> Double
- */
+ * When averaging Spark doesn't change these types:
+ * - BigDecimal -> BigDecimal
+ * - Double -> Double
+ * But it changes these types:
+ * - Int -> Double
+ * - Short -> Double
+ * - Long -> Double
+ */
@implicitNotFound("Cannot compute average of type ${In}.")
trait CatalystAverageable[In, Out]
object CatalystAverageable {
private[this] val theInstance = new CatalystAverageable[Any, Any] {}
- private[this] def of[In, Out]: CatalystAverageable[In, Out] = theInstance.asInstanceOf[CatalystAverageable[In, Out]]
- implicit val framelessAverageableBigDecimal: CatalystAverageable[BigDecimal, BigDecimal] = of[BigDecimal, BigDecimal]
- implicit val framelessAverageableDouble: CatalystAverageable[Double, Double] = of[Double, Double]
- implicit val framelessAverageableLong: CatalystAverageable[Long, Double] = of[Long, Double]
- implicit val framelessAverageableInt: CatalystAverageable[Int, Double] = of[Int, Double]
- implicit val framelessAverageableShort: CatalystAverageable[Short, Double] = of[Short, Double]
+ private[this] def of[In, Out]: CatalystAverageable[In, Out] =
+ theInstance.asInstanceOf[CatalystAverageable[In, Out]]
+
+ implicit val framelessAverageableBigDecimal: CatalystAverageable[BigDecimal, BigDecimal] =
+ of[BigDecimal, BigDecimal]
+
+ implicit val framelessAverageableDouble: CatalystAverageable[Double, Double] =
+ of[Double, Double]
+
+ implicit val framelessAverageableLong: CatalystAverageable[Long, Double] =
+ of[Long, Double]
+
+ implicit val framelessAverageableInt: CatalystAverageable[Int, Double] =
+ of[Int, Double]
+
+ implicit val framelessAverageableShort: CatalystAverageable[Short, Double] =
+ of[Short, Double]
}
diff --git a/core/src/main/scala/frameless/CatalystBitShift.scala b/core/src/main/scala/frameless/CatalystBitShift.scala
index 753a61907..3e1cdbefa 100644
--- a/core/src/main/scala/frameless/CatalystBitShift.scala
+++ b/core/src/main/scala/frameless/CatalystBitShift.scala
@@ -2,19 +2,29 @@ package frameless
import scala.annotation.implicitNotFound
-/** Spark does not return always Int on shift
- */
+/**
+ * Spark does not always return Int on shift
+ */
@implicitNotFound("Cannot do bit shift operations on columns of type ${In}.")
trait CatalystBitShift[In, Out]
object CatalystBitShift {
private[this] val theInstance = new CatalystBitShift[Any, Any] {}
- private[this] def of[In, Out]: CatalystBitShift[In, Out] = theInstance.asInstanceOf[CatalystBitShift[In, Out]]
- implicit val framelessBitShiftBigDecimal: CatalystBitShift[BigDecimal, Int] = of[BigDecimal, Int]
- implicit val framelessBitShiftDouble : CatalystBitShift[Byte, Int] = of[Byte, Int]
- implicit val framelessBitShiftInt : CatalystBitShift[Short, Int] = of[Short, Int]
- implicit val framelessBitShiftLong : CatalystBitShift[Int, Int] = of[Int, Int]
- implicit val framelessBitShiftShort : CatalystBitShift[Long, Long] = of[Long, Long]
+ private[this] def of[In, Out]: CatalystBitShift[In, Out] =
+ theInstance.asInstanceOf[CatalystBitShift[In, Out]]
+
+ implicit val framelessBitShiftBigDecimal: CatalystBitShift[BigDecimal, Int] =
+ of[BigDecimal, Int]
+
+ implicit val framelessBitShiftDouble: CatalystBitShift[Byte, Int] =
+ of[Byte, Int]
+
+ implicit val framelessBitShiftInt: CatalystBitShift[Short, Int] =
+ of[Short, Int]
+ implicit val framelessBitShiftLong: CatalystBitShift[Int, Int] = of[Int, Int]
+
+ implicit val framelessBitShiftShort: CatalystBitShift[Long, Long] =
+ of[Long, Long]
}
diff --git a/core/src/main/scala/frameless/CatalystCast.scala b/core/src/main/scala/frameless/CatalystCast.scala
index 1a8a21573..94fa59a69 100644
--- a/core/src/main/scala/frameless/CatalystCast.scala
+++ b/core/src/main/scala/frameless/CatalystCast.scala
@@ -4,29 +4,62 @@ trait CatalystCast[A, B]
object CatalystCast {
private[this] val theInstance = new CatalystCast[Any, Any] {}
- private[this] def of[A, B]: CatalystCast[A, B] = theInstance.asInstanceOf[CatalystCast[A, B]]
+
+ private[this] def of[A, B]: CatalystCast[A, B] =
+ theInstance.asInstanceOf[CatalystCast[A, B]]
implicit def framelessCastToString[T]: CatalystCast[T, String] = of[T, String]
- implicit def framelessNumericToLong [A: CatalystNumeric]: CatalystCast[A, Long] = of[A, Long]
- implicit def framelessNumericToInt [A: CatalystNumeric]: CatalystCast[A, Int] = of[A, Int]
- implicit def framelessNumericToShort [A: CatalystNumeric]: CatalystCast[A, Short] = of[A, Short]
- implicit def framelessNumericToByte [A: CatalystNumeric]: CatalystCast[A, Byte] = of[A, Byte]
- implicit def framelessNumericToDecimal[A: CatalystNumeric]: CatalystCast[A, BigDecimal] = of[A, BigDecimal]
- implicit def framelessNumericToDouble [A: CatalystNumeric]: CatalystCast[A, Double] = of[A, Double]
+ implicit def framelessNumericToLong[
+ A: CatalystNumeric
+ ]: CatalystCast[A, Long] = of[A, Long]
+
+ implicit def framelessNumericToInt[A: CatalystNumeric]: CatalystCast[A, Int] =
+ of[A, Int]
+
+ implicit def framelessNumericToShort[
+ A: CatalystNumeric
+ ]: CatalystCast[A, Short] = of[A, Short]
+
+ implicit def framelessNumericToByte[
+ A: CatalystNumeric
+ ]: CatalystCast[A, Byte] = of[A, Byte]
- implicit def framelessBooleanToNumeric[A: CatalystNumeric]: CatalystCast[Boolean, A] = of[Boolean, A]
+ implicit def framelessNumericToDecimal[
+ A: CatalystNumeric
+ ]: CatalystCast[A, BigDecimal] = of[A, BigDecimal]
+
+ implicit def framelessNumericToDouble[
+ A: CatalystNumeric
+ ]: CatalystCast[A, Double] = of[A, Double]
+
+ implicit def framelessBooleanToNumeric[
+ A: CatalystNumeric
+ ]: CatalystCast[Boolean, A] = of[Boolean, A]
// doesn't make any sense to include:
// - sqlDateToBoolean: always None
// - sqlTimestampToBoolean: compares us to 0
- implicit val framelessStringToBoolean : CatalystCast[String, Option[Boolean]] = of[String, Option[Boolean]]
- implicit val framelessLongToBoolean : CatalystCast[Long, Boolean] = of[Long, Boolean]
- implicit val framelessIntToBoolean : CatalystCast[Int, Boolean] = of[Int, Boolean]
- implicit val framelessShortToBoolean : CatalystCast[Short, Boolean] = of[Short, Boolean]
- implicit val framelessByteToBoolean : CatalystCast[Byte, Boolean] = of[Byte, Boolean]
- implicit val framelessBigDecimalToBoolean: CatalystCast[BigDecimal, Boolean] = of[BigDecimal, Boolean]
- implicit val framelessDoubleToBoolean : CatalystCast[Double, Boolean] = of[Double, Boolean]
+ implicit val framelessStringToBoolean: CatalystCast[String, Option[Boolean]] =
+ of[String, Option[Boolean]]
+
+ implicit val framelessLongToBoolean: CatalystCast[Long, Boolean] =
+ of[Long, Boolean]
+
+ implicit val framelessIntToBoolean: CatalystCast[Int, Boolean] =
+ of[Int, Boolean]
+
+ implicit val framelessShortToBoolean: CatalystCast[Short, Boolean] =
+ of[Short, Boolean]
+
+ implicit val framelessByteToBoolean: CatalystCast[Byte, Boolean] =
+ of[Byte, Boolean]
+
+ implicit val framelessBigDecimalToBoolean: CatalystCast[BigDecimal, Boolean] =
+ of[BigDecimal, Boolean]
+
+ implicit val framelessDoubleToBoolean: CatalystCast[Double, Boolean] =
+ of[Double, Boolean]
// TODO
@@ -38,9 +71,8 @@ object CatalystCast {
// implicit object stringToLong extends CatalystCast[String, Option[Long]]
// implicit object stringToSqlDate extends CatalystCast[String, Option[SQLDate]]
-
// needs verification:
- //implicit object sqlTimestampToSqlDate extends CatalystCast[SQLTimestamp, SQLDate]
+ // implicit object sqlTimestampToSqlDate extends CatalystCast[SQLTimestamp, SQLDate]
// needs verification:
// implicit object sqlTimestampToDecimal extends CatalystCast[SQLTimestamp, BigDecimal]
diff --git a/core/src/main/scala/frameless/CatalystCollection.scala b/core/src/main/scala/frameless/CatalystCollection.scala
index 3456869a0..731a7385d 100644
--- a/core/src/main/scala/frameless/CatalystCollection.scala
+++ b/core/src/main/scala/frameless/CatalystCollection.scala
@@ -7,10 +7,12 @@ trait CatalystCollection[C[_]]
object CatalystCollection {
private[this] val theInstance = new CatalystCollection[Any] {}
- private[this] def of[A[_]]: CatalystCollection[A] = theInstance.asInstanceOf[CatalystCollection[A]]
- implicit val arrayObject : CatalystCollection[Array] = of[Array]
- implicit val seqObject : CatalystCollection[Seq] = of[Seq]
- implicit val listObject : CatalystCollection[List] = of[List]
+ private[this] def of[A[_]]: CatalystCollection[A] =
+ theInstance.asInstanceOf[CatalystCollection[A]]
+
+ implicit val arrayObject: CatalystCollection[Array] = of[Array]
+ implicit val seqObject: CatalystCollection[Seq] = of[Seq]
+ implicit val listObject: CatalystCollection[List] = of[List]
implicit val vectorObject: CatalystCollection[Vector] = of[Vector]
}
diff --git a/core/src/main/scala/frameless/CatalystDivisible.scala b/core/src/main/scala/frameless/CatalystDivisible.scala
index c9080a5d8..01d849174 100644
--- a/core/src/main/scala/frameless/CatalystDivisible.scala
+++ b/core/src/main/scala/frameless/CatalystDivisible.scala
@@ -2,20 +2,34 @@ package frameless
import scala.annotation.implicitNotFound
-/** Spark divides everything as Double, expect BigDecimals are divided into
- * another BigDecimal, benefiting from some added precision.
- */
+/**
+ * Spark divides everything as Double, except BigDecimals, which are divided into
+ * another BigDecimal, benefiting from some added precision.
+ */
@implicitNotFound("Cannot compute division on type ${In}.")
trait CatalystDivisible[In, Out]
object CatalystDivisible {
private[this] val theInstance = new CatalystDivisible[Any, Any] {}
- private[this] def of[In, Out]: CatalystDivisible[In, Out] = theInstance.asInstanceOf[CatalystDivisible[In, Out]]
-
- implicit val framelessDivisibleBigDecimal: CatalystDivisible[BigDecimal, BigDecimal] = of[BigDecimal, BigDecimal]
- implicit val framelessDivisibleDouble : CatalystDivisible[Double, Double] = of[Double, Double]
- implicit val framelessDivisibleInt : CatalystDivisible[Int, Double] = of[Int, Double]
- implicit val framelessDivisibleLong : CatalystDivisible[Long, Double] = of[Long, Double]
- implicit val framelessDivisibleByte : CatalystDivisible[Byte, Double] = of[Byte, Double]
- implicit val framelessDivisibleShort : CatalystDivisible[Short, Double] = of[Short, Double]
+
+ private[this] def of[In, Out]: CatalystDivisible[In, Out] =
+ theInstance.asInstanceOf[CatalystDivisible[In, Out]]
+
+ implicit val framelessDivisibleBigDecimal: CatalystDivisible[BigDecimal, BigDecimal] =
+ of[BigDecimal, BigDecimal]
+
+ implicit val framelessDivisibleDouble: CatalystDivisible[Double, Double] =
+ of[Double, Double]
+
+ implicit val framelessDivisibleInt: CatalystDivisible[Int, Double] =
+ of[Int, Double]
+
+ implicit val framelessDivisibleLong: CatalystDivisible[Long, Double] =
+ of[Long, Double]
+
+ implicit val framelessDivisibleByte: CatalystDivisible[Byte, Double] =
+ of[Byte, Double]
+
+ implicit val framelessDivisibleShort: CatalystDivisible[Short, Double] =
+ of[Short, Double]
}
diff --git a/core/src/main/scala/frameless/CatalystIsin.scala b/core/src/main/scala/frameless/CatalystIsin.scala
index f630a7155..fe12ab622 100644
--- a/core/src/main/scala/frameless/CatalystIsin.scala
+++ b/core/src/main/scala/frameless/CatalystIsin.scala
@@ -8,11 +8,11 @@ trait CatalystIsin[A]
object CatalystIsin {
implicit object framelessBigDecimal extends CatalystIsin[BigDecimal]
- implicit object framelessByte extends CatalystIsin[Byte]
- implicit object framelessDouble extends CatalystIsin[Double]
- implicit object framelessFloat extends CatalystIsin[Float]
- implicit object framelessInt extends CatalystIsin[Int]
- implicit object framelessLong extends CatalystIsin[Long]
- implicit object framelessShort extends CatalystIsin[Short]
- implicit object framelesssString extends CatalystIsin[String]
+ implicit object framelessByte extends CatalystIsin[Byte]
+ implicit object framelessDouble extends CatalystIsin[Double]
+ implicit object framelessFloat extends CatalystIsin[Float]
+ implicit object framelessInt extends CatalystIsin[Int]
+ implicit object framelessLong extends CatalystIsin[Long]
+ implicit object framelessShort extends CatalystIsin[Short]
+ implicit object framelesssString extends CatalystIsin[String]
}
diff --git a/core/src/main/scala/frameless/CatalystNaN.scala b/core/src/main/scala/frameless/CatalystNaN.scala
index 3e7be8263..56549ed24 100644
--- a/core/src/main/scala/frameless/CatalystNaN.scala
+++ b/core/src/main/scala/frameless/CatalystNaN.scala
@@ -8,9 +8,10 @@ trait CatalystNaN[A]
object CatalystNaN {
private[this] val theInstance = new CatalystNaN[Any] {}
- private[this] def of[A]: CatalystNaN[A] = theInstance.asInstanceOf[CatalystNaN[A]]
- implicit val framelessFloatNaN : CatalystNaN[Float] = of[Float]
- implicit val framelessDoubleNaN : CatalystNaN[Double] = of[Double]
-}
+ private[this] def of[A]: CatalystNaN[A] =
+ theInstance.asInstanceOf[CatalystNaN[A]]
+ implicit val framelessFloatNaN: CatalystNaN[Float] = of[Float]
+ implicit val framelessDoubleNaN: CatalystNaN[Double] = of[Double]
+}
diff --git a/core/src/main/scala/frameless/CatalystNotNullable.scala b/core/src/main/scala/frameless/CatalystNotNullable.scala
index e8d4b3be1..58636455d 100644
--- a/core/src/main/scala/frameless/CatalystNotNullable.scala
+++ b/core/src/main/scala/frameless/CatalystNotNullable.scala
@@ -6,13 +6,20 @@ import scala.annotation.implicitNotFound
trait CatalystNullable[A]
object CatalystNullable {
- implicit def optionIsNullable[A]: CatalystNullable[Option[A]] = new CatalystNullable[Option[A]] {}
+
+ implicit def optionIsNullable[A]: CatalystNullable[Option[A]] =
+ new CatalystNullable[Option[A]] {}
}
@implicitNotFound("Cannot find evidence that type ${A} is not nullable.")
trait NotCatalystNullable[A]
object NotCatalystNullable {
- implicit def everythingIsNotNullable[A]: NotCatalystNullable[A] = new NotCatalystNullable[A] {}
- implicit def nullableIsNotNotNullable[A: CatalystNullable]: NotCatalystNullable[A] = new NotCatalystNullable[A] {}
+
+ implicit def everythingIsNotNullable[A]: NotCatalystNullable[A] =
+ new NotCatalystNullable[A] {}
+
+ implicit def nullableIsNotNotNullable[
+ A: CatalystNullable
+ ]: NotCatalystNullable[A] = new NotCatalystNullable[A] {}
}
diff --git a/core/src/main/scala/frameless/CatalystNumeric.scala b/core/src/main/scala/frameless/CatalystNumeric.scala
index c819ba2ae..fd3aa027e 100644
--- a/core/src/main/scala/frameless/CatalystNumeric.scala
+++ b/core/src/main/scala/frameless/CatalystNumeric.scala
@@ -8,12 +8,15 @@ trait CatalystNumeric[A]
object CatalystNumeric {
private[this] val theInstance = new CatalystNumeric[Any] {}
- private[this] def of[A]: CatalystNumeric[A] = theInstance.asInstanceOf[CatalystNumeric[A]]
- implicit val framelessbigDecimalNumeric: CatalystNumeric[BigDecimal] = of[BigDecimal]
- implicit val framelessbyteNumeric : CatalystNumeric[Byte] = of[Byte]
- implicit val framelessdoubleNumeric : CatalystNumeric[Double] = of[Double]
- implicit val framelessintNumeric : CatalystNumeric[Int] = of[Int]
- implicit val framelesslongNumeric : CatalystNumeric[Long] = of[Long]
- implicit val framelessshortNumeric : CatalystNumeric[Short] = of[Short]
+ private[this] def of[A]: CatalystNumeric[A] =
+ theInstance.asInstanceOf[CatalystNumeric[A]]
+
+ implicit val framelessbigDecimalNumeric: CatalystNumeric[BigDecimal] =
+ of[BigDecimal]
+ implicit val framelessbyteNumeric: CatalystNumeric[Byte] = of[Byte]
+ implicit val framelessdoubleNumeric: CatalystNumeric[Double] = of[Double]
+ implicit val framelessintNumeric: CatalystNumeric[Int] = of[Int]
+ implicit val framelesslongNumeric: CatalystNumeric[Long] = of[Long]
+ implicit val framelessshortNumeric: CatalystNumeric[Short] = of[Short]
}
diff --git a/core/src/main/scala/frameless/CatalystNumericWithJavaBigDecimal.scala b/core/src/main/scala/frameless/CatalystNumericWithJavaBigDecimal.scala
index 8fee63be2..08f69def0 100644
--- a/core/src/main/scala/frameless/CatalystNumericWithJavaBigDecimal.scala
+++ b/core/src/main/scala/frameless/CatalystNumericWithJavaBigDecimal.scala
@@ -2,20 +2,36 @@ package frameless
import scala.annotation.implicitNotFound
-/** Spark does not return always the same type as the input was for example abs
- */
+/**
+ * Spark does not always return the same type as the input, for example abs
+ */
@implicitNotFound("Cannot compute on type ${In}.")
trait CatalystNumericWithJavaBigDecimal[In, Out]
object CatalystNumericWithJavaBigDecimal {
- private[this] val theInstance = new CatalystNumericWithJavaBigDecimal[Any, Any] {}
- private[this] def of[In, Out]: CatalystNumericWithJavaBigDecimal[In, Out] = theInstance.asInstanceOf[CatalystNumericWithJavaBigDecimal[In, Out]]
- implicit val framelessAbsoluteBigDecimal: CatalystNumericWithJavaBigDecimal[BigDecimal, java.math.BigDecimal] = of[BigDecimal, java.math.BigDecimal]
- implicit val framelessAbsoluteDouble : CatalystNumericWithJavaBigDecimal[Double, Double] = of[Double, Double]
- implicit val framelessAbsoluteInt : CatalystNumericWithJavaBigDecimal[Int, Int] = of[Int, Int]
- implicit val framelessAbsoluteLong : CatalystNumericWithJavaBigDecimal[Long, Long] = of[Long, Long]
- implicit val framelessAbsoluteShort : CatalystNumericWithJavaBigDecimal[Short, Short] = of[Short, Short]
- implicit val framelessAbsoluteByte : CatalystNumericWithJavaBigDecimal[Byte, Byte] = of[Byte, Byte]
+ private[this] val theInstance =
+ new CatalystNumericWithJavaBigDecimal[Any, Any] {}
-}
\ No newline at end of file
+ private[this] def of[In, Out]: CatalystNumericWithJavaBigDecimal[In, Out] =
+ theInstance.asInstanceOf[CatalystNumericWithJavaBigDecimal[In, Out]]
+
+ implicit val framelessAbsoluteBigDecimal: CatalystNumericWithJavaBigDecimal[BigDecimal, java.math.BigDecimal] =
+ of[BigDecimal, java.math.BigDecimal]
+
+ implicit val framelessAbsoluteDouble: CatalystNumericWithJavaBigDecimal[Double, Double] =
+ of[Double, Double]
+
+ implicit val framelessAbsoluteInt: CatalystNumericWithJavaBigDecimal[Int, Int] =
+ of[Int, Int]
+
+ implicit val framelessAbsoluteLong: CatalystNumericWithJavaBigDecimal[Long, Long] =
+ of[Long, Long]
+
+ implicit val framelessAbsoluteShort: CatalystNumericWithJavaBigDecimal[Short, Short] =
+ of[Short, Short]
+
+ implicit val framelessAbsoluteByte: CatalystNumericWithJavaBigDecimal[Byte, Byte] =
+ of[Byte, Byte]
+
+}
diff --git a/core/src/main/scala/frameless/CatalystOrdered.scala b/core/src/main/scala/frameless/CatalystOrdered.scala
index e73604909..119ba248e 100644
--- a/core/src/main/scala/frameless/CatalystOrdered.scala
+++ b/core/src/main/scala/frameless/CatalystOrdered.scala
@@ -1,9 +1,9 @@
package frameless
import scala.annotation.implicitNotFound
-import shapeless.{Generic, HList, Lazy}
+import shapeless.{ Generic, HList, Lazy }
import shapeless.ops.hlist.LiftAll
-import java.time.{Duration, Instant, Period}
+import java.time.{ Duration, Instant, Period }
/** Types that can be ordered/compared by Catalyst. */
@implicitNotFound("Cannot compare columns of type ${A}.")
@@ -11,31 +11,39 @@ trait CatalystOrdered[A]
object CatalystOrdered {
private[this] val theInstance = new CatalystOrdered[Any] {}
- private[this] def of[A]: CatalystOrdered[A] = theInstance.asInstanceOf[CatalystOrdered[A]]
-
- implicit val framelessIntOrdered : CatalystOrdered[Int] = of[Int]
- implicit val framelessBooleanOrdered : CatalystOrdered[Boolean] = of[Boolean]
- implicit val framelessByteOrdered : CatalystOrdered[Byte] = of[Byte]
- implicit val framelessShortOrdered : CatalystOrdered[Short] = of[Short]
- implicit val framelessLongOrdered : CatalystOrdered[Long] = of[Long]
- implicit val framelessFloatOrdered : CatalystOrdered[Float] = of[Float]
- implicit val framelessDoubleOrdered : CatalystOrdered[Double] = of[Double]
- implicit val framelessBigDecimalOrdered : CatalystOrdered[BigDecimal] = of[BigDecimal]
- implicit val framelessSQLDateOrdered : CatalystOrdered[SQLDate] = of[SQLDate]
- implicit val framelessSQLTimestampOrdered: CatalystOrdered[SQLTimestamp] = of[SQLTimestamp]
- implicit val framelessStringOrdered : CatalystOrdered[String] = of[String]
- implicit val framelessInstantOrdered : CatalystOrdered[Instant] = of[Instant]
- implicit val framelessDurationOrdered : CatalystOrdered[Duration] = of[Duration]
- implicit val framelessPeriodOrdered : CatalystOrdered[Period] = of[Period]
-
- implicit def injectionOrdered[A, B]
- (implicit
+
+ private[this] def of[A]: CatalystOrdered[A] =
+ theInstance.asInstanceOf[CatalystOrdered[A]]
+
+ implicit val framelessIntOrdered: CatalystOrdered[Int] = of[Int]
+ implicit val framelessBooleanOrdered: CatalystOrdered[Boolean] = of[Boolean]
+ implicit val framelessByteOrdered: CatalystOrdered[Byte] = of[Byte]
+ implicit val framelessShortOrdered: CatalystOrdered[Short] = of[Short]
+ implicit val framelessLongOrdered: CatalystOrdered[Long] = of[Long]
+ implicit val framelessFloatOrdered: CatalystOrdered[Float] = of[Float]
+ implicit val framelessDoubleOrdered: CatalystOrdered[Double] = of[Double]
+
+ implicit val framelessBigDecimalOrdered: CatalystOrdered[BigDecimal] =
+ of[BigDecimal]
+ implicit val framelessSQLDateOrdered: CatalystOrdered[SQLDate] = of[SQLDate]
+
+ implicit val framelessSQLTimestampOrdered: CatalystOrdered[SQLTimestamp] =
+ of[SQLTimestamp]
+ implicit val framelessStringOrdered: CatalystOrdered[String] = of[String]
+ implicit val framelessInstantOrdered: CatalystOrdered[Instant] = of[Instant]
+
+ implicit val framelessDurationOrdered: CatalystOrdered[Duration] =
+ of[Duration]
+ implicit val framelessPeriodOrdered: CatalystOrdered[Period] = of[Period]
+
+ implicit def injectionOrdered[A, B](
+ implicit
i0: Injection[A, B],
i1: CatalystOrdered[B]
): CatalystOrdered[A] = of[A]
- implicit def deriveGeneric[G, H <: HList]
- (implicit
+ implicit def deriveGeneric[G, H <: HList](
+ implicit
i0: Generic.Aux[G, H],
i1: Lazy[LiftAll[CatalystOrdered, H]]
): CatalystOrdered[G] = of[G]
diff --git a/core/src/main/scala/frameless/CatalystPivotable.scala b/core/src/main/scala/frameless/CatalystPivotable.scala
index a7b34da64..091d82818 100644
--- a/core/src/main/scala/frameless/CatalystPivotable.scala
+++ b/core/src/main/scala/frameless/CatalystPivotable.scala
@@ -7,10 +7,14 @@ trait CatalystPivotable[A]
object CatalystPivotable {
private[this] val theInstance = new CatalystPivotable[Any] {}
- private[this] def of[A]: CatalystPivotable[A] = theInstance.asInstanceOf[CatalystPivotable[A]]
- implicit val framelessIntPivotable : CatalystPivotable[Int] = of[Int]
- implicit val framelessLongPivotable : CatalystPivotable[Long] = of[Long]
- implicit val framelessBooleanPivotable: CatalystPivotable[Boolean] = of[Boolean]
- implicit val framelessStringPivotable : CatalystPivotable[String] = of[String]
+ private[this] def of[A]: CatalystPivotable[A] =
+ theInstance.asInstanceOf[CatalystPivotable[A]]
+
+ implicit val framelessIntPivotable: CatalystPivotable[Int] = of[Int]
+ implicit val framelessLongPivotable: CatalystPivotable[Long] = of[Long]
+
+ implicit val framelessBooleanPivotable: CatalystPivotable[Boolean] =
+ of[Boolean]
+ implicit val framelessStringPivotable: CatalystPivotable[String] = of[String]
}
diff --git a/core/src/main/scala/frameless/CatalystRound.scala b/core/src/main/scala/frameless/CatalystRound.scala
index ee50b794a..92b5c1c54 100644
--- a/core/src/main/scala/frameless/CatalystRound.scala
+++ b/core/src/main/scala/frameless/CatalystRound.scala
@@ -2,18 +2,22 @@ package frameless
import scala.annotation.implicitNotFound
-/** Spark does not return always long on round
- */
+/**
+ * Spark does not always return Long on round
+ */
@implicitNotFound("Cannot compute round on type ${In}.")
trait CatalystRound[In, Out]
object CatalystRound {
private[this] val theInstance = new CatalystRound[Any, Any] {}
- private[this] def of[In, Out]: CatalystRound[In, Out] = theInstance.asInstanceOf[CatalystRound[In, Out]]
- implicit val framelessBigDecimal: CatalystRound[BigDecimal, java.math.BigDecimal] = of[BigDecimal, java.math.BigDecimal]
- implicit val framelessDouble : CatalystRound[Double, Long] = of[Double, Long]
- implicit val framelessInt : CatalystRound[Int, Long] = of[Int, Long]
- implicit val framelessLong : CatalystRound[Long, Long] = of[Long, Long]
- implicit val framelessShort : CatalystRound[Short, Long] = of[Short, Long]
-}
\ No newline at end of file
+ private[this] def of[In, Out]: CatalystRound[In, Out] =
+ theInstance.asInstanceOf[CatalystRound[In, Out]]
+
+ implicit val framelessBigDecimal: CatalystRound[BigDecimal, java.math.BigDecimal] =
+ of[BigDecimal, java.math.BigDecimal]
+ implicit val framelessDouble: CatalystRound[Double, Long] = of[Double, Long]
+ implicit val framelessInt: CatalystRound[Int, Long] = of[Int, Long]
+ implicit val framelessLong: CatalystRound[Long, Long] = of[Long, Long]
+ implicit val framelessShort: CatalystRound[Short, Long] = of[Short, Long]
+}
diff --git a/core/src/main/scala/frameless/CatalystSummable.scala b/core/src/main/scala/frameless/CatalystSummable.scala
index 94010505e..8de1bb6f3 100644
--- a/core/src/main/scala/frameless/CatalystSummable.scala
+++ b/core/src/main/scala/frameless/CatalystSummable.scala
@@ -3,29 +3,39 @@ package frameless
import scala.annotation.implicitNotFound
/**
- * When summing Spark doesn't change these types:
- * - Long -> Long
- * - BigDecimal -> BigDecimal
- * - Double -> Double
- *
- * For other types there are conversions:
- * - Int -> Long
- * - Short -> Long
- */
+ * When summing Spark doesn't change these types:
+ * - Long -> Long
+ * - BigDecimal -> BigDecimal
+ * - Double -> Double
+ *
+ * For other types there are conversions:
+ * - Int -> Long
+ * - Short -> Long
+ */
@implicitNotFound("Cannot compute sum of type ${In}.")
trait CatalystSummable[In, Out] {
def zero: In
}
object CatalystSummable {
+
def apply[In, Out](zero: In): CatalystSummable[In, Out] = {
val _zero = zero
new CatalystSummable[In, Out] { val zero: In = _zero }
}
- implicit val framelessSummableLong : CatalystSummable[Long, Long] = CatalystSummable(zero = 0L)
- implicit val framelessSummableBigDecimal: CatalystSummable[BigDecimal, BigDecimal] = CatalystSummable(zero = BigDecimal(0))
- implicit val framelessSummableDouble : CatalystSummable[Double, Double] = CatalystSummable(zero = 0.0)
- implicit val framelessSummableInt : CatalystSummable[Int, Long] = CatalystSummable(zero = 0)
- implicit val framelessSummableShort : CatalystSummable[Short, Long] = CatalystSummable(zero = 0)
+ implicit val framelessSummableLong: CatalystSummable[Long, Long] =
+ CatalystSummable(zero = 0L)
+
+ implicit val framelessSummableBigDecimal: CatalystSummable[BigDecimal, BigDecimal] =
+ CatalystSummable(zero = BigDecimal(0))
+
+ implicit val framelessSummableDouble: CatalystSummable[Double, Double] =
+ CatalystSummable(zero = 0.0)
+
+ implicit val framelessSummableInt: CatalystSummable[Int, Long] =
+ CatalystSummable(zero = 0)
+
+ implicit val framelessSummableShort: CatalystSummable[Short, Long] =
+ CatalystSummable(zero = 0)
}
diff --git a/core/src/main/scala/frameless/CatalystVariance.scala b/core/src/main/scala/frameless/CatalystVariance.scala
index 9e843fa70..e9f935549 100644
--- a/core/src/main/scala/frameless/CatalystVariance.scala
+++ b/core/src/main/scala/frameless/CatalystVariance.scala
@@ -3,18 +3,22 @@ package frameless
import scala.annotation.implicitNotFound
/**
- * Spark's variance and stddev functions always return Double
- */
+ * Spark's variance and stddev functions always return Double
+ */
@implicitNotFound("Cannot compute variance on type ${A}.")
trait CatalystVariance[A]
object CatalystVariance {
private[this] val theInstance = new CatalystVariance[Any] {}
- private[this] def of[A]: CatalystVariance[A] = theInstance.asInstanceOf[CatalystVariance[A]]
- implicit val framelessIntVariance : CatalystVariance[Int] = of[Int]
- implicit val framelessLongVariance : CatalystVariance[Long] = of[Long]
- implicit val framelessShortVariance : CatalystVariance[Short] = of[Short]
- implicit val framelessBigDecimalVariance: CatalystVariance[BigDecimal] = of[BigDecimal]
- implicit val framelessDoubleVariance : CatalystVariance[Double] = of[Double]
+ private[this] def of[A]: CatalystVariance[A] =
+ theInstance.asInstanceOf[CatalystVariance[A]]
+
+ implicit val framelessIntVariance: CatalystVariance[Int] = of[Int]
+ implicit val framelessLongVariance: CatalystVariance[Long] = of[Long]
+ implicit val framelessShortVariance: CatalystVariance[Short] = of[Short]
+
+ implicit val framelessBigDecimalVariance: CatalystVariance[BigDecimal] =
+ of[BigDecimal]
+ implicit val framelessDoubleVariance: CatalystVariance[Double] = of[Double]
}
diff --git a/dataset/src/main/scala/frameless/FramelessSyntax.scala b/dataset/src/main/scala/frameless/FramelessSyntax.scala
index 5ba294921..3967314df 100644
--- a/dataset/src/main/scala/frameless/FramelessSyntax.scala
+++ b/dataset/src/main/scala/frameless/FramelessSyntax.scala
@@ -1,18 +1,25 @@
package frameless
-import org.apache.spark.sql.{Column, DataFrame, Dataset}
+import org.apache.spark.sql.{ Column, DataFrame, Dataset }
trait FramelessSyntax {
+
implicit class ColumnSyntax(self: Column) {
- def typedColumn[T, U: TypedEncoder]: TypedColumn[T, U] = new TypedColumn[T, U](self)
- def typedAggregate[T, U: TypedEncoder]: TypedAggregate[T, U] = new TypedAggregate[T, U](self)
+
+ def typedColumn[T, U: TypedEncoder]: TypedColumn[T, U] =
+ new TypedColumn[T, U](self)
+
+ def typedAggregate[T, U: TypedEncoder]: TypedAggregate[T, U] =
+ new TypedAggregate[T, U](self)
}
implicit class DatasetSyntax[T: TypedEncoder](self: Dataset[T]) {
def typed: TypedDataset[T] = TypedDataset.create[T](self)
}
- implicit class DataframeSyntax(self: DataFrame){
- def unsafeTyped[T: TypedEncoder]: TypedDataset[T] = TypedDataset.createUnsafe(self)
+ implicit class DataframeSyntax(self: DataFrame) {
+
+ def unsafeTyped[T: TypedEncoder]: TypedDataset[T] =
+ TypedDataset.createUnsafe(self)
}
}
diff --git a/dataset/src/main/scala/frameless/InjectionEnum.scala b/dataset/src/main/scala/frameless/InjectionEnum.scala
index 4ed1006e3..677b321ef 100644
--- a/dataset/src/main/scala/frameless/InjectionEnum.scala
+++ b/dataset/src/main/scala/frameless/InjectionEnum.scala
@@ -3,6 +3,7 @@ package frameless
import shapeless._
trait InjectionEnum {
+
implicit val cnilInjectionEnum: Injection[CNil, String] =
Injection(
// $COVERAGE-OFF$No value of type CNil so impossible to test
@@ -15,10 +16,10 @@ trait InjectionEnum {
)
implicit def coproductInjectionEnum[H, T <: Coproduct](
- implicit
- typeable: Typeable[H] ,
- gen: Generic.Aux[H, HNil],
- tInjectionEnum: Injection[T, String]
+ implicit
+ typeable: Typeable[H],
+ gen: Generic.Aux[H, HNil],
+ tInjectionEnum: Injection[T, String]
): Injection[H :+: T, String] = {
val dataConstructorName = typeable.describe.takeWhile(_ != '.')
@@ -37,9 +38,9 @@ trait InjectionEnum {
}
implicit def genericInjectionEnum[A, R](
- implicit
- gen: Generic.Aux[A, R],
- rInjectionEnum: Injection[R, String]
+ implicit
+ gen: Generic.Aux[A, R],
+ rInjectionEnum: Injection[R, String]
): Injection[A, String] =
Injection(
value => rInjectionEnum(gen.to(value)),
diff --git a/dataset/src/main/scala/frameless/IsValueClass.scala b/dataset/src/main/scala/frameless/IsValueClass.scala
index 78605c130..c8097e785 100644
--- a/dataset/src/main/scala/frameless/IsValueClass.scala
+++ b/dataset/src/main/scala/frameless/IsValueClass.scala
@@ -5,13 +5,18 @@ import shapeless.labelled.FieldType
/** Evidence that `T` is a Value class */
@annotation.implicitNotFound(msg = "${T} is not a Value class")
-final class IsValueClass[T] private() {}
+final class IsValueClass[T] private () {}
object IsValueClass {
+
/** Provides an evidence `A` is a Value class */
- implicit def apply[A <: AnyVal, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil]](
- implicit
+ implicit def apply[
+ A <: AnyVal,
+ G <: ::[_, HNil],
+ H <: ::[_ <: FieldType[_ <: Symbol, _], HNil]
+ ](implicit
i0: LabelledGeneric.Aux[A, G],
- i1: DropUnitValues.Aux[G, H]): IsValueClass[A] = new IsValueClass[A]
+ i1: DropUnitValues.Aux[G, H]
+ ): IsValueClass[A] = new IsValueClass[A]
}
diff --git a/dataset/src/main/scala/frameless/Job.scala b/dataset/src/main/scala/frameless/Job.scala
index 40931b8b4..ead6ecf09 100644
--- a/dataset/src/main/scala/frameless/Job.scala
+++ b/dataset/src/main/scala/frameless/Job.scala
@@ -2,7 +2,10 @@ package frameless
import org.apache.spark.sql.SparkSession
-sealed abstract class Job[A](implicit spark: SparkSession) { self =>
+sealed abstract class Job[A](
+ implicit
+ spark: SparkSession) { self =>
+
/** Runs a new Spark job. */
def run(): A
@@ -32,13 +35,22 @@ sealed abstract class Job[A](implicit spark: SparkSession) { self =>
}
}
-
object Job {
- def apply[A](a: => A)(implicit spark: SparkSession): Job[A] = new Job[A] {
+
+ def apply[A](
+ a: => A
+ )(implicit
+ spark: SparkSession
+ ): Job[A] = new Job[A] {
def run(): A = a
}
- implicit val framelessSparkDelayForJob: SparkDelay[Job] = new SparkDelay[Job] {
- def delay[A](a: => A)(implicit spark: SparkSession): Job[A] = Job(a)
- }
+ implicit val framelessSparkDelayForJob: SparkDelay[Job] =
+ new SparkDelay[Job] {
+ def delay[A](
+ a: => A
+ )(implicit
+ spark: SparkSession
+ ): Job[A] = Job(a)
+ }
}
diff --git a/dataset/src/main/scala/frameless/RecordEncoder.scala b/dataset/src/main/scala/frameless/RecordEncoder.scala
index 7427d9de0..269a4879a 100644
--- a/dataset/src/main/scala/frameless/RecordEncoder.scala
+++ b/dataset/src/main/scala/frameless/RecordEncoder.scala
@@ -4,7 +4,10 @@ import org.apache.spark.sql.FramelessInternals
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.{
- Invoke, NewInstance, UnwrapOption, WrapOption
+ Invoke,
+ NewInstance,
+ UnwrapOption,
+ WrapOption
}
import org.apache.spark.sql.types._
@@ -16,10 +19,9 @@ import shapeless.ops.record.Keys
import scala.reflect.ClassTag
case class RecordEncoderField(
- ordinal: Int,
- name: String,
- encoder: TypedEncoder[_]
-)
+ ordinal: Int,
+ name: String,
+ encoder: TypedEncoder[_])
trait RecordEncoderFields[T <: HList] extends Serializable {
def value: List[RecordEncoderField]
@@ -30,33 +32,40 @@ trait RecordEncoderFields[T <: HList] extends Serializable {
object RecordEncoderFields {
- implicit def deriveRecordLast[K <: Symbol, H]
- (implicit
+ implicit def deriveRecordLast[K <: Symbol, H](
+ implicit
key: Witness.Aux[K],
head: RecordFieldEncoder[H]
- ): RecordEncoderFields[FieldType[K, H] :: HNil] = new RecordEncoderFields[FieldType[K, H] :: HNil] {
+ ): RecordEncoderFields[FieldType[K, H] :: HNil] =
+ new RecordEncoderFields[FieldType[K, H] :: HNil] {
def value: List[RecordEncoderField] = fieldEncoder[K, H] :: Nil
}
- implicit def deriveRecordCons[K <: Symbol, H, T <: HList]
- (implicit
+ implicit def deriveRecordCons[K <: Symbol, H, T <: HList](
+ implicit
key: Witness.Aux[K],
head: RecordFieldEncoder[H],
tail: RecordEncoderFields[T]
- ): RecordEncoderFields[FieldType[K, H] :: T] = new RecordEncoderFields[FieldType[K, H] :: T] {
+ ): RecordEncoderFields[FieldType[K, H] :: T] =
+ new RecordEncoderFields[FieldType[K, H] :: T] {
def value: List[RecordEncoderField] =
- fieldEncoder[K, H] :: tail.value.map(x => x.copy(ordinal = x.ordinal + 1))
- }
+ fieldEncoder[K, H] :: tail.value
+ .map(x => x.copy(ordinal = x.ordinal + 1))
+ }
- private def fieldEncoder[K <: Symbol, H](implicit key: Witness.Aux[K], e: RecordFieldEncoder[H]): RecordEncoderField = RecordEncoderField(0, key.value.name, e.encoder)
+ private def fieldEncoder[K <: Symbol, H](
+ implicit
+ key: Witness.Aux[K],
+ e: RecordFieldEncoder[H]
+ ): RecordEncoderField = RecordEncoderField(0, key.value.name, e.encoder)
}
/**
- * Assists the generation of constructor call parameters from a labelled generic representation.
- * As Unit typed fields were removed earlier, we need to put back unit literals in the appropriate positions.
- *
- * @tparam T labelled generic representation of type fields
- */
+ * Assists the generation of constructor call parameters from a labelled generic representation.
+ * As Unit typed fields were removed earlier, we need to put back unit literals in the appropriate positions.
+ *
+ * @tparam T labelled generic representation of type fields
+ */
trait NewInstanceExprs[T <: HList] extends Serializable {
def from(exprs: List[Expression]): Seq[Expression]
}
@@ -67,32 +76,41 @@ object NewInstanceExprs {
def from(exprs: List[Expression]): Seq[Expression] = Nil
}
- implicit def deriveUnit[K <: Symbol, T <: HList]
- (implicit
+ implicit def deriveUnit[K <: Symbol, T <: HList](
+ implicit
tail: NewInstanceExprs[T]
- ): NewInstanceExprs[FieldType[K, Unit] :: T] = new NewInstanceExprs[FieldType[K, Unit] :: T] {
+ ): NewInstanceExprs[FieldType[K, Unit] :: T] =
+ new NewInstanceExprs[FieldType[K, Unit] :: T] {
def from(exprs: List[Expression]): Seq[Expression] =
Literal.fromObject(()) +: tail.from(exprs)
}
- implicit def deriveNonUnit[K <: Symbol, V, T <: HList]
- (implicit
+ implicit def deriveNonUnit[K <: Symbol, V, T <: HList](
+ implicit
notUnit: V =:!= Unit,
tail: NewInstanceExprs[T]
- ): NewInstanceExprs[FieldType[K, V] :: T] = new NewInstanceExprs[FieldType[K, V] :: T] {
- def from(exprs: List[Expression]): Seq[Expression] = exprs.head +: tail.from(exprs.tail)
+ ): NewInstanceExprs[FieldType[K, V] :: T] =
+ new NewInstanceExprs[FieldType[K, V] :: T] {
+ def from(exprs: List[Expression]): Seq[Expression] =
+ exprs.head +: tail.from(exprs.tail)
}
}
/**
- * Drops fields with Unit type from labelled generic representation of types.
- *
- * @tparam L labelled generic representation of type fields
- */
-trait DropUnitValues[L <: HList] extends DepFn1[L] with Serializable { type Out <: HList }
+ * Drops fields with Unit type from labelled generic representation of types.
+ *
+ * @tparam L labelled generic representation of type fields
+ */
+trait DropUnitValues[L <: HList] extends DepFn1[L] with Serializable {
+ type Out <: HList
+}
object DropUnitValues {
- def apply[L <: HList](implicit dropUnitValues: DropUnitValues[L]): Aux[L, dropUnitValues.Out] = dropUnitValues
+
+ def apply[L <: HList](
+ implicit
+ dropUnitValues: DropUnitValues[L]
+ ): Aux[L, dropUnitValues.Out] = dropUnitValues
type Aux[L <: HList, Out0 <: HList] = DropUnitValues[L] { type Out = Out0 }
@@ -101,93 +119,94 @@ object DropUnitValues {
def apply(l: HNil): Out = HNil
}
- implicit def deriveUnit[K <: Symbol, T <: HList, OutT <: HList]
- (implicit
- dropUnitValues : DropUnitValues.Aux[T, OutT]
- ): Aux[FieldType[K, Unit] :: T, OutT] = new DropUnitValues[FieldType[K, Unit] :: T] {
+ implicit def deriveUnit[K <: Symbol, T <: HList, OutT <: HList](
+ implicit
+ dropUnitValues: DropUnitValues.Aux[T, OutT]
+ ): Aux[FieldType[K, Unit] :: T, OutT] =
+ new DropUnitValues[FieldType[K, Unit] :: T] {
type Out = OutT
- def apply(l : FieldType[K, Unit] :: T): Out = dropUnitValues(l.tail)
+ def apply(l: FieldType[K, Unit] :: T): Out = dropUnitValues(l.tail)
}
- implicit def deriveNonUnit[K <: Symbol, V, T <: HList, OutH, OutT <: HList]
- (implicit
+ implicit def deriveNonUnit[K <: Symbol, V, T <: HList, OutH, OutT <: HList](
+ implicit
nonUnit: V =:!= Unit,
- dropUnitValues : DropUnitValues.Aux[T, OutT]
- ): Aux[FieldType[K, V] :: T, FieldType[K, V] :: OutT] = new DropUnitValues[FieldType[K, V] :: T] {
+ dropUnitValues: DropUnitValues.Aux[T, OutT]
+ ): Aux[FieldType[K, V] :: T, FieldType[K, V] :: OutT] =
+ new DropUnitValues[FieldType[K, V] :: T] {
type Out = FieldType[K, V] :: OutT
- def apply(l : FieldType[K, V] :: T): Out = l.head :: dropUnitValues(l.tail)
+ def apply(l: FieldType[K, V] :: T): Out = l.head :: dropUnitValues(l.tail)
}
}
-class RecordEncoder[F, G <: HList, H <: HList]
- (implicit
+class RecordEncoder[F, G <: HList, H <: HList](
+ implicit
i0: LabelledGeneric.Aux[F, G],
i1: DropUnitValues.Aux[G, H],
i2: IsHCons[H],
fields: Lazy[RecordEncoderFields[H]],
newInstanceExprs: Lazy[NewInstanceExprs[G]],
- classTag: ClassTag[F]
- ) extends TypedEncoder[F] {
- def nullable: Boolean = false
-
- def jvmRepr: DataType = FramelessInternals.objectTypeFor[F]
-
- def catalystRepr: DataType = {
- val structFields = fields.value.value.map { field =>
- StructField(
- name = field.name,
- dataType = field.encoder.catalystRepr,
- nullable = field.encoder.nullable,
- metadata = Metadata.empty
- )
- }
-
- StructType(structFields)
+ classTag: ClassTag[F])
+ extends TypedEncoder[F] {
+ def nullable: Boolean = false
+
+ def jvmRepr: DataType = FramelessInternals.objectTypeFor[F]
+
+ def catalystRepr: DataType = {
+ val structFields = fields.value.value.map { field =>
+ StructField(
+ name = field.name,
+ dataType = field.encoder.catalystRepr,
+ nullable = field.encoder.nullable,
+ metadata = Metadata.empty
+ )
}
- def toCatalyst(path: Expression): Expression = {
- val nameExprs = fields.value.value.map { field =>
- Literal(field.name)
- }
-
- val valueExprs = fields.value.value.map { field =>
- val fieldPath = Invoke(path, field.name, field.encoder.jvmRepr, Nil)
- field.encoder.toCatalyst(fieldPath)
- }
-
- // the way exprs are encoded in CreateNamedStruct
- val exprs = nameExprs.zip(valueExprs).flatMap {
- case (nameExpr, valueExpr) => nameExpr :: valueExpr :: Nil
- }
+ StructType(structFields)
+ }
- val createExpr = CreateNamedStruct(exprs)
- val nullExpr = Literal.create(null, createExpr.dataType)
+ def toCatalyst(path: Expression): Expression = {
+ val nameExprs = fields.value.value.map { field => Literal(field.name) }
- If(IsNull(path), nullExpr, createExpr)
+ val valueExprs = fields.value.value.map { field =>
+ val fieldPath = Invoke(path, field.name, field.encoder.jvmRepr, Nil)
+ field.encoder.toCatalyst(fieldPath)
}
- def fromCatalyst(path: Expression): Expression = {
- val exprs = fields.value.value.map { field =>
- field.encoder.fromCatalyst(
- GetStructField(path, field.ordinal, Some(field.name)))
- }
+ // the way exprs are encoded in CreateNamedStruct
+ val exprs = nameExprs.zip(valueExprs).flatMap {
+ case (nameExpr, valueExpr) => nameExpr :: valueExpr :: Nil
+ }
- val newArgs = newInstanceExprs.value.from(exprs)
- val newExpr = NewInstance(
- classTag.runtimeClass, newArgs, jvmRepr, propagateNull = true)
+ val createExpr = CreateNamedStruct(exprs)
+ val nullExpr = Literal.create(null, createExpr.dataType)
- val nullExpr = Literal.create(null, jvmRepr)
+ If(IsNull(path), nullExpr, createExpr)
+ }
- If(IsNull(path), nullExpr, newExpr)
+ def fromCatalyst(path: Expression): Expression = {
+ val exprs = fields.value.value.map { field =>
+ field.encoder.fromCatalyst(
+ GetStructField(path, field.ordinal, Some(field.name))
+ )
}
+
+ val newArgs = newInstanceExprs.value.from(exprs)
+ val newExpr =
+ NewInstance(classTag.runtimeClass, newArgs, jvmRepr, propagateNull = true)
+
+ val nullExpr = Literal.create(null, jvmRepr)
+
+ If(IsNull(path), nullExpr, newExpr)
+ }
}
final class RecordFieldEncoder[T](
- val encoder: TypedEncoder[T],
- private[frameless] val jvmRepr: DataType,
- private[frameless] val fromCatalyst: Expression => Expression,
- private[frameless] val toCatalyst: Expression => Expression
-) extends Serializable
+ val encoder: TypedEncoder[T],
+ private[frameless] val jvmRepr: DataType,
+ private[frameless] val fromCatalyst: Expression => Expression,
+ private[frameless] val toCatalyst: Expression => Expression)
+ extends Serializable
object RecordFieldEncoder extends RecordFieldEncoderLowPriority {
@@ -198,8 +217,14 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority {
* @tparam K the key type for the fields
* @tparam V the inner value type
*/
- implicit def optionValueClass[F : IsValueClass, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], K <: Symbol, V, KS <: ::[_ <: Symbol, HNil]]
- (implicit
+ implicit def optionValueClass[
+ F: IsValueClass,
+ G <: ::[_, HNil],
+ H <: ::[_ <: FieldType[_ <: Symbol, _], HNil],
+ K <: Symbol,
+ V,
+ KS <: ::[_ <: Symbol, HNil]
+ ](implicit
i0: LabelledGeneric.Aux[F, G],
i1: DropUnitValues.Aux[G, H],
i2: IsHCons.Aux[H, _ <: FieldType[K, V], HNil],
@@ -208,49 +233,49 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority {
i5: TypedEncoder[V],
i6: ClassTag[F]
): RecordFieldEncoder[Option[F]] = {
- val fieldName = i4.head(i3()).name
- val innerJvmRepr = ObjectType(i6.runtimeClass)
+ val fieldName = i4.head(i3()).name
+ val innerJvmRepr = ObjectType(i6.runtimeClass)
- val catalyst: Expression => Expression = { path =>
- val value = UnwrapOption(innerJvmRepr, path)
- val javaValue = Invoke(value, fieldName, i5.jvmRepr, Nil)
+ val catalyst: Expression => Expression = { path =>
+ val value = UnwrapOption(innerJvmRepr, path)
+ val javaValue = Invoke(value, fieldName, i5.jvmRepr, Nil)
- i5.toCatalyst(javaValue)
- }
+ i5.toCatalyst(javaValue)
+ }
- val fromCatalyst: Expression => Expression = { path =>
- val javaValue = i5.fromCatalyst(path)
- val value = NewInstance(i6.runtimeClass, Seq(javaValue), innerJvmRepr)
+ val fromCatalyst: Expression => Expression = { path =>
+ val javaValue = i5.fromCatalyst(path)
+ val value = NewInstance(i6.runtimeClass, Seq(javaValue), innerJvmRepr)
- WrapOption(value, innerJvmRepr)
- }
+ WrapOption(value, innerJvmRepr)
+ }
- val jvmr = ObjectType(classOf[Option[F]])
+ val jvmr = ObjectType(classOf[Option[F]])
- new RecordFieldEncoder[Option[F]](
- encoder = new TypedEncoder[Option[F]] {
- val nullable = true
+ new RecordFieldEncoder[Option[F]](
+ encoder = new TypedEncoder[Option[F]] {
+ val nullable = true
- val jvmRepr = jvmr
+ val jvmRepr = jvmr
- @inline def catalystRepr: DataType = i5.catalystRepr
+ @inline def catalystRepr: DataType = i5.catalystRepr
- def fromCatalyst(path: Expression): Expression = {
- val javaValue = i5.fromCatalyst(path)
- val value = NewInstance(
- i6.runtimeClass, Seq(javaValue), innerJvmRepr)
+ def fromCatalyst(path: Expression): Expression = {
+ val javaValue = i5.fromCatalyst(path)
+ val value = NewInstance(i6.runtimeClass, Seq(javaValue), innerJvmRepr)
- WrapOption(value, innerJvmRepr)
- }
+ WrapOption(value, innerJvmRepr)
+ }
- def toCatalyst(path: Expression): Expression = catalyst(path)
+ def toCatalyst(path: Expression): Expression = catalyst(path)
- override def toString: String = s"RecordFieldEncoder.optionValueClass[${i6.runtimeClass.getName}]('${fieldName}', $i5)"
- },
- jvmRepr = jvmr,
- fromCatalyst = fromCatalyst,
- toCatalyst = catalyst
- )
+ override def toString: String =
+ s"RecordFieldEncoder.optionValueClass[${i6.runtimeClass.getName}]('${fieldName}', $i5)"
+ },
+ jvmRepr = jvmr,
+ fromCatalyst = fromCatalyst,
+ toCatalyst = catalyst
+ )
}
/**
@@ -259,8 +284,14 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority {
* @tparam H the single field of the value class (with guarantee it's not a `Unit` value)
* @tparam V the inner value type
*/
- implicit def valueClass[F : IsValueClass, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], K <: Symbol, V, KS <: ::[_ <: Symbol, HNil]]
- (implicit
+ implicit def valueClass[
+ F: IsValueClass,
+ G <: ::[_, HNil],
+ H <: ::[_ <: FieldType[_ <: Symbol, _], HNil],
+ K <: Symbol,
+ V,
+ KS <: ::[_ <: Symbol, HNil]
+ ](implicit
i0: LabelledGeneric.Aux[F, G],
i1: DropUnitValues.Aux[G, H],
i2: IsHCons.Aux[H, _ <: FieldType[K, V], HNil],
@@ -269,40 +300,47 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority {
i5: TypedEncoder[V],
i6: ClassTag[F]
): RecordFieldEncoder[F] = {
- val cls = i6.runtimeClass
- val jvmr = i5.jvmRepr
- val fieldName = i4.head(i3()).name
-
- new RecordFieldEncoder[F](
- encoder = new TypedEncoder[F] {
- def nullable = i5.nullable
-
- def jvmRepr = jvmr
-
- def catalystRepr: DataType = i5.catalystRepr
-
- def fromCatalyst(path: Expression): Expression =
- i5.fromCatalyst(path)
-
- @inline def toCatalyst(path: Expression): Expression =
- i5.toCatalyst(path)
-
- override def toString: String = s"RecordFieldEncoder.valueClass[${cls.getName}]('${fieldName}', ${i5})"
- },
- jvmRepr = FramelessInternals.objectTypeFor[F],
- fromCatalyst = { expr: Expression =>
- NewInstance(
- i6.runtimeClass,
- i5.fromCatalyst(expr) :: Nil,
- ObjectType(i6.runtimeClass))
- },
- toCatalyst = { expr: Expression =>
- i5.toCatalyst(Invoke(expr, fieldName, jvmr))
- }
- )
+ val cls = i6.runtimeClass
+ val jvmr = i5.jvmRepr
+ val fieldName = i4.head(i3()).name
+
+ new RecordFieldEncoder[F](
+ encoder = new TypedEncoder[F] {
+ def nullable = i5.nullable
+
+ def jvmRepr = jvmr
+
+ def catalystRepr: DataType = i5.catalystRepr
+
+ def fromCatalyst(path: Expression): Expression =
+ i5.fromCatalyst(path)
+
+ @inline def toCatalyst(path: Expression): Expression =
+ i5.toCatalyst(path)
+
+ override def toString: String =
+ s"RecordFieldEncoder.valueClass[${cls.getName}]('${fieldName}', ${i5})"
+ },
+ jvmRepr = FramelessInternals.objectTypeFor[F],
+ fromCatalyst = { expr: Expression =>
+ NewInstance(
+ i6.runtimeClass,
+ i5.fromCatalyst(expr) :: Nil,
+ ObjectType(i6.runtimeClass)
+ )
+ },
+ toCatalyst = { expr: Expression =>
+ i5.toCatalyst(Invoke(expr, fieldName, jvmr))
+ }
+ )
}
}
private[frameless] sealed trait RecordFieldEncoderLowPriority {
- implicit def apply[T](implicit e: TypedEncoder[T]): RecordFieldEncoder[T] = new RecordFieldEncoder[T](e, e.jvmRepr, e.fromCatalyst, e.toCatalyst)
+
+ implicit def apply[T](
+ implicit
+ e: TypedEncoder[T]
+ ): RecordFieldEncoder[T] =
+ new RecordFieldEncoder[T](e, e.jvmRepr, e.fromCatalyst, e.toCatalyst)
}
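
For context only, and not part of this diff: a minimal sketch of what the value-class derivation reformatted above provides at the use site. The `Price` and `Item` names are hypothetical; the point is that a single-field `AnyVal` wrapper is encoded as its underlying column rather than as a nested struct.

    import frameless.TypedEncoder

    // Hypothetical domain types, for illustration only.
    final case class Price(value: Double) extends AnyVal
    final case class Item(name: String, price: Price)

    // Summoning the encoder exercises the RecordEncoder / RecordFieldEncoder
    // derivation; the `price` field maps to a flat DoubleType column in the
    // Catalyst schema instead of a nested struct.
    val itemEncoder: TypedEncoder[Item] = TypedEncoder[Item]
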
diff --git a/dataset/src/main/scala/frameless/SparkDelay.scala b/dataset/src/main/scala/frameless/SparkDelay.scala
index 74a651ae3..83e78d3c3 100644
--- a/dataset/src/main/scala/frameless/SparkDelay.scala
+++ b/dataset/src/main/scala/frameless/SparkDelay.scala
@@ -3,5 +3,10 @@ package frameless
import org.apache.spark.sql.SparkSession
trait SparkDelay[F[_]] {
- def delay[A](a: => A)(implicit spark: SparkSession): F[A]
+
+ def delay[A](
+ a: => A
+ )(implicit
+ spark: SparkSession
+ ): F[A]
}
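
As a hedged sketch of the reformatted trait's shape (the instance below is illustrative and not part of frameless): an eager interpreter that simply forces the by-name argument.

    import org.apache.spark.sql.SparkSession
    import frameless.SparkDelay

    type Id[A] = A

    // Illustrative only: evaluates the thunk immediately instead of
    // suspending it in an effect type, which is what real instances do.
    val eagerDelay: SparkDelay[Id] = new SparkDelay[Id] {
      def delay[A](a: => A)(implicit spark: SparkSession): Id[A] = a
    }
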
diff --git a/dataset/src/main/scala/frameless/TypedColumn.scala b/dataset/src/main/scala/frameless/TypedColumn.scala
index 0bbaf6fed..ef3fe52b0 100644
--- a/dataset/src/main/scala/frameless/TypedColumn.scala
+++ b/dataset/src/main/scala/frameless/TypedColumn.scala
@@ -1,11 +1,11 @@
package frameless
-import frameless.functions.{litAggr, lit => flit}
+import frameless.functions.{ litAggr, lit => flit }
import frameless.syntax._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.DecimalType
-import org.apache.spark.sql.{Column, FramelessInternals}
+import org.apache.spark.sql.{ Column, FramelessInternals }
import shapeless._
import shapeless.ops.record.Selector
@@ -21,91 +21,121 @@ sealed trait UntypedExpression[T] {
override def toString: String = expr.toString()
}
-/** Expression used in `select`-like constructions.
- */
-sealed class TypedColumn[T, U](expr: Expression)(
- implicit val uenc: TypedEncoder[U]
-) extends AbstractTypedColumn[T, U](expr) {
+/**
+ * Expression used in `select`-like constructions.
+ */
+sealed class TypedColumn[T, U](
+ expr: Expression
+ )(implicit
+ val uenc: TypedEncoder[U])
+ extends AbstractTypedColumn[T, U](expr) {
type ThisType[A, B] = TypedColumn[A, B]
- def this(column: Column)(implicit uencoder: TypedEncoder[U]) =
+ def this(
+ column: Column
+ )(implicit
+ uencoder: TypedEncoder[U]
+ ) =
this(FramelessInternals.expr(column))
- override def typed[W, U1: TypedEncoder](c: Column): TypedColumn[W, U1] = c.typedColumn
+ override def typed[W, U1: TypedEncoder](c: Column): TypedColumn[W, U1] =
+ c.typedColumn
override def lit[U1: TypedEncoder](c: U1): TypedColumn[T, U1] = flit(c)
}
-/** Expression used in `agg`-like constructions.
- */
-sealed class TypedAggregate[T, U](expr: Expression)(
- implicit val uenc: TypedEncoder[U]
-) extends AbstractTypedColumn[T, U](expr) {
+/**
+ * Expression used in `agg`-like constructions.
+ */
+sealed class TypedAggregate[T, U](
+ expr: Expression
+ )(implicit
+ val uenc: TypedEncoder[U])
+ extends AbstractTypedColumn[T, U](expr) {
type ThisType[A, B] = TypedAggregate[A, B]
- def this(column: Column)(implicit uencoder: TypedEncoder[U]) = {
+ def this(
+ column: Column
+ )(implicit
+ uencoder: TypedEncoder[U]
+ ) = {
this(FramelessInternals.expr(column))
}
- override def typed[W, U1: TypedEncoder](c: Column): TypedAggregate[W, U1] = c.typedAggregate
+ override def typed[W, U1: TypedEncoder](c: Column): TypedAggregate[W, U1] =
+ c.typedAggregate
override def lit[U1: TypedEncoder](c: U1): TypedAggregate[T, U1] = litAggr(c)
}
-/** Generic representation of a typed column. A typed column can either be a [[TypedAggregate]] or
- * a [[frameless.TypedColumn]].
- *
- * Documentation marked "apache/spark" is thanks to apache/spark Contributors
- * at https://github.com/apache/spark, licensed under Apache v2.0 available at
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * @tparam T phantom type representing the dataset on which this columns is
- * selected. When `T = A with B` the selection is on either A or B.
- * @tparam U type of column
- */
-abstract class AbstractTypedColumn[T, U]
- (val expr: Expression)
- (implicit val uencoder: TypedEncoder[U])
+/**
+ * Generic representation of a typed column. A typed column can either be a [[TypedAggregate]] or
+ * a [[frameless.TypedColumn]].
+ *
+ * Documentation marked "apache/spark" is thanks to apache/spark Contributors
+ * at https://github.com/apache/spark, licensed under Apache v2.0 available at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * @tparam T phantom type representing the dataset on which this column is
+ * selected. When `T = A with B` the selection is on either A or B.
+ * @tparam U type of column
+ */
+abstract class AbstractTypedColumn[T, U](
+ val expr: Expression
+ )(implicit
+ val uencoder: TypedEncoder[U])
extends UntypedExpression[T] { self =>
type ThisType[A, B] <: AbstractTypedColumn[A, B]
- /** A helper class to make to simplify working with Optional fields.
- *
- * {{{
- * val x: TypedColumn[Option[Int]] = _
- * x.opt.map(_*2) // This only compiles if the type of x is Option[X] (in this example X is of type Int)
- * }}}
- *
- * @note Known issue: map() will NOT work when the applied function is a udf().
- * It will compile and then throw a runtime error.
- **/
+ /**
+ * A helper class to simplify working with Optional fields.
+ *
+ * {{{
+ * val x: TypedColumn[T, Option[Int]] = _
+ * x.opt.map(_*2) // This only compiles if the type of x is Option[X] (in this example X is of type Int)
+ * }}}
+ *
+ * @note Known issue: map() will NOT work when the applied function is a udf().
+ * It will compile and then throw a runtime error.
+ */
trait Mapper[X] {
- def map[G, OutputType[_,_]](u: ThisType[T, X] => OutputType[T,G])
- (implicit
- ev: OutputType[T,G] <:< AbstractTypedColumn[T, G]
+
+ def map[G, OutputType[_, _]](
+ u: ThisType[T, X] => OutputType[T, G]
+ )(implicit
+ ev: OutputType[T, G] <:< AbstractTypedColumn[T, G]
): OutputType[T, Option[G]] = {
- u(self.asInstanceOf[ThisType[T, X]]).asInstanceOf[OutputType[T, Option[G]]]
+ u(self.asInstanceOf[ThisType[T, X]])
+ .asInstanceOf[OutputType[T, Option[G]]]
}
}
- /** Makes it easier to work with Optional columns. It returns an instance of `Mapper[X]`
- * where `X` is type of the unwrapped Optional. E.g., in the case of `Option[Long]`,
- * `X` is of type Long.
- *
- * {{{
- * val x: TypedColumn[Option[Int]] = _
- * x.opt.map(_*2)
- * }}}
- * */
- def opt[X](implicit x: U <:< Option[X]): Mapper[X] = new Mapper[X] {}
+ /**
+ * Makes it easier to work with Optional columns. It returns an instance of `Mapper[X]`
+ * where `X` is the type of the unwrapped Optional. E.g., in the case of `Option[Long]`,
+ * `X` is of type Long.
+ *
+ * {{{
+ * val x: TypedColumn[T, Option[Int]] = _
+ * x.opt.map(_*2)
+ * }}}
+ */
+ def opt[X](
+ implicit
+ x: U <:< Option[X]
+ ): Mapper[X] = new Mapper[X] {}
/** Fall back to an untyped Column */
def untyped: Column = new Column(expr)
- private def equalsTo[TT, W](other: ThisType[TT, U])(implicit w: With.Aux[T, TT, W]): ThisType[W, Boolean] = typed {
+ private def equalsTo[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, Boolean] = typed {
if (uencoder.nullable) EqualNullSafe(self.expr, other.expr)
else EqualTo(self.expr, other.expr)
}
@@ -120,773 +150,1125 @@ abstract class AbstractTypedColumn[T, U]
/** Creates a typed column of either TypedColumn or TypedAggregate. */
def lit[U1: TypedEncoder](c: U1): ThisType[T, U1]
- /** Equality test.
- * {{{
- * df.filter( df.col('a) === 1 )
- * }}}
- *
- * apache/spark
- */
+ /**
+ * Equality test.
+ * {{{
+ * df.filter( df.col('a) === 1 )
+ * }}}
+ *
+ * apache/spark
+ */
def ===(u: U): ThisType[T, Boolean] =
equalsTo(lit(u))
- /** Equality test.
- * {{{
- * df.filter( df.col('a) === df.col('b) )
- * }}}
- *
- * apache/spark
- */
- def ===[TT, W](other: ThisType[TT, U])(implicit w: With.Aux[T, TT, W]): ThisType[W, Boolean] =
+ /**
+ * Equality test.
+ * {{{
+ * df.filter( df.col('a) === df.col('b) )
+ * }}}
+ *
+ * apache/spark
+ */
+ def ===[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, Boolean] =
equalsTo(other)
- /** Inequality test.
- *
- * {{{
- * df.filter(df.col('a) =!= df.col('b))
- * }}}
- *
- * apache/spark
- */
- def =!=[TT, W](other: ThisType[TT, U])(implicit w: With.Aux[T, TT, W]): ThisType[W, Boolean] =
+ /**
+ * Inequality test.
+ *
+ * {{{
+ * df.filter(df.col('a) =!= df.col('b))
+ * }}}
+ *
+ * apache/spark
+ */
+ def =!=[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, Boolean] =
typed(Not(equalsTo(other).expr))
- /** Inequality test.
- *
- * {{{
- * df.filter(df.col('a) =!= "a")
- * }}}
- *
- * apache/spark
- */
+ /**
+ * Inequality test.
+ *
+ * {{{
+ * df.filter(df.col('a) =!= "a")
+ * }}}
+ *
+ * apache/spark
+ */
def =!=(u: U): ThisType[T, Boolean] = typed(Not(equalsTo(lit(u)).expr))
- /** True if the current expression is an Option and it's None.
- *
- * apache/spark
- */
- def isNone(implicit i0: U <:< Option[_]): ThisType[T, Boolean] =
+ /**
+ * True if the current expression is an Option and it's None.
+ *
+ * apache/spark
+ */
+ def isNone(
+ implicit
+ i0: U <:< Option[_]
+ ): ThisType[T, Boolean] =
typed(IsNull(expr))
- /** True if the current expression is an Option and it's not None.
- *
- * apache/spark
- */
- def isNotNone(implicit i0: U <:< Option[_]): ThisType[T, Boolean] =
+ /**
+ * True if the current expression is an Option and it's not None.
+ *
+ * apache/spark
+ */
+ def isNotNone(
+ implicit
+ i0: U <:< Option[_]
+ ): ThisType[T, Boolean] =
typed(IsNotNull(expr))
- /** True if the current expression is a fractional number and is not NaN.
- *
- * apache/spark
- */
- def isNaN(implicit n: CatalystNaN[U]): ThisType[T, Boolean] =
+ /**
+ * True if the current expression is a fractional number and is NaN.
+ *
+ * apache/spark
+ */
+ def isNaN(
+ implicit
+ n: CatalystNaN[U]
+ ): ThisType[T, Boolean] =
typed(self.untyped.isNaN)
/**
- * True if the value for this optional column `exists` as expected
- * (see `Option.exists`).
- *
- * {{{
- * df.col('opt).isSome(_ === someOtherCol)
- * }}}
- */
- def isSome[V](exists: ThisType[T, V] => ThisType[T, Boolean])(implicit i0: U <:< Option[V]): ThisType[T, Boolean] = someOr[V](exists, false)
+ * True if the value for this optional column `exists` as expected
+ * (see `Option.exists`).
+ *
+ * {{{
+ * df.col('opt).isSome(_ === someOtherCol)
+ * }}}
+ */
+ def isSome[V](
+ exists: ThisType[T, V] => ThisType[T, Boolean]
+ )(implicit
+ i0: U <:< Option[V]
+ ): ThisType[T, Boolean] = someOr[V](exists, false)
/**
- * True if the value for this optional column `exists` as expected,
- * or is `None`. (see `Option.forall`).
- *
- * {{{
- * df.col('opt).isSomeOrNone(_ === someOtherCol)
- * }}}
- */
- def isSomeOrNone[V](exists: ThisType[T, V] => ThisType[T, Boolean])(implicit i0: U <:< Option[V]): ThisType[T, Boolean] = someOr[V](exists, true)
-
- private def someOr[V](exists: ThisType[T, V] => ThisType[T, Boolean], default: Boolean)(implicit i0: U <:< Option[V]): ThisType[T, Boolean] = {
+ * True if the value for this optional column `exists` as expected,
+ * or is `None` (see `Option.forall`).
+ *
+ * {{{
+ * df.col('opt).isSomeOrNone(_ === someOtherCol)
+ * }}}
+ */
+ def isSomeOrNone[V](
+ exists: ThisType[T, V] => ThisType[T, Boolean]
+ )(implicit
+ i0: U <:< Option[V]
+ ): ThisType[T, Boolean] = someOr[V](exists, true)
+
+ private def someOr[V](
+ exists: ThisType[T, V] => ThisType[T, Boolean],
+ default: Boolean
+ )(implicit
+ i0: U <:< Option[V]
+ ): ThisType[T, Boolean] = {
val defaultExpr = if (default) Literal.TrueLiteral else Literal.FalseLiteral
typed(Coalesce(Seq(opt(i0).map(exists).expr, defaultExpr)))
}
- /** Convert an Optional column by providing a default value.
- *
- * {{{
- * df(df('opt).getOrElse(df('defaultValue)))
- * }}}
- */
- def getOrElse[TT, W, Out](default: ThisType[TT, Out])(implicit i0: U =:= Option[Out], i1: With.Aux[T, TT, W]): ThisType[W, Out] =
+ /**
+ * Convert an Optional column by providing a default value.
+ *
+ * {{{
+ * df(df('opt).getOrElse(df('defaultValue)))
+ * }}}
+ */
+ def getOrElse[TT, W, Out](
+ default: ThisType[TT, Out]
+ )(implicit
+ i0: U =:= Option[Out],
+ i1: With.Aux[T, TT, W]
+ ): ThisType[W, Out] =
typed(Coalesce(Seq(expr, default.expr)))(default.uencoder)
- /** Convert an Optional column by providing a default value.
- *
- * {{{
- * df( df('opt).getOrElse(defaultConstant) )
- * }}}
- */
- def getOrElse[Out: TypedEncoder](default: Out)(implicit i0: U =:= Option[Out]): ThisType[T, Out] =
+ /**
+ * Convert an Optional column by providing a default value.
+ *
+ * {{{
+ * df( df('opt).getOrElse(defaultConstant) )
+ * }}}
+ */
+ def getOrElse[Out: TypedEncoder](
+ default: Out
+ )(implicit
+ i0: U =:= Option[Out]
+ ): ThisType[T, Out] =
getOrElse(lit[Out](default))
- /** Sum of this expression and another expression.
- *
- * {{{
- * // The following selects the sum of a person's height and weight.
- * people.select( people.col('height) plus people.col('weight) )
- * }}}
- *
- * apache/spark
- */
- def plus[TT, W](other: ThisType[TT, U])(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W]): ThisType[W, U] =
+ /**
+ * Sum of this expression and another expression.
+ *
+ * {{{
+ * // The following selects the sum of a person's height and weight.
+ * people.select( people.col('height) plus people.col('weight) )
+ * }}}
+ *
+ * apache/spark
+ */
+ def plus[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ n: CatalystNumeric[U],
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, U] =
typed(self.untyped.plus(other.untyped))
- /** Sum of this expression and another expression.
- * {{{
- * // The following selects the sum of a person's height and weight.
- * people.select( people.col('height) + people.col('weight) )
- * }}}
- *
- * apache/spark
- */
- def +[TT, W](other: ThisType[TT, U])(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W]): ThisType[W, U] =
+ /**
+ * Sum of this expression and another expression.
+ * {{{
+ * // The following selects the sum of a person's height and weight.
+ * people.select( people.col('height) + people.col('weight) )
+ * }}}
+ *
+ * apache/spark
+ */
+ def +[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ n: CatalystNumeric[U],
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, U] =
plus(other)
- /** Sum of this expression (column) with a constant.
- * {{{
- * // The following selects the sum of a person's height and weight.
- * people.select( people('height) + 2 )
- * }}}
- *
- * @param u a constant of the same type
- * apache/spark
- */
- def +(u: U)(implicit n: CatalystNumeric[U]): ThisType[T, U] =
+ /**
+ * Sum of this expression (column) with a constant.
+ * {{{
+ * // The following adds 2 to a person's height.
+ * people.select( people('height) + 2 )
+ * }}}
+ *
+ * @param u a constant of the same type
+ * apache/spark
+ */
+ def +(
+ u: U
+ )(implicit
+ n: CatalystNumeric[U]
+ ): ThisType[T, U] =
typed(self.untyped.plus(u))
/**
- * Inversion of boolean expression, i.e. NOT.
- * {{{
- * // Select rows that are not active (isActive === false)
- * df.filter( !df('isActive) )
- * }}}
- *
- * apache/spark
- */
- def unary_!(implicit i0: U <:< Boolean): ThisType[T, Boolean] =
+ * Inversion of boolean expression, i.e. NOT.
+ * {{{
+ * // Select rows that are not active (isActive === false)
+ * df.filter( !df('isActive) )
+ * }}}
+ *
+ * apache/spark
+ */
+ def unary_!(
+ implicit
+ i0: U <:< Boolean
+ ): ThisType[T, Boolean] =
typed(!untyped)
- /** Unary minus, i.e. negate the expression.
- * {{{
- * // Select the amount column and negates all values.
- * df.select( -df('amount) )
- * }}}
- *
- * apache/spark
- */
- def unary_-(implicit n: CatalystNumeric[U]): ThisType[T, U] =
+ /**
+ * Unary minus, i.e. negate the expression.
+ * {{{
+ * // Select the amount column and negates all values.
+ * df.select( -df('amount) )
+ * }}}
+ *
+ * apache/spark
+ */
+ def unary_-(
+ implicit
+ n: CatalystNumeric[U]
+ ): ThisType[T, U] =
typed(-self.untyped)
- /** Subtraction. Subtract the other expression from this expression.
- * {{{
- * // The following selects the difference between people's height and their weight.
- * people.select( people.col('height) minus people.col('weight) )
- * }}}
- *
- * apache/spark
- */
- def minus[TT, W](other: ThisType[TT, U])(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W]): ThisType[W, U] =
+ /**
+ * Subtraction. Subtract the other expression from this expression.
+ * {{{
+ * // The following selects the difference between people's height and their weight.
+ * people.select( people.col('height) minus people.col('weight) )
+ * }}}
+ *
+ * apache/spark
+ */
+ def minus[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ n: CatalystNumeric[U],
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, U] =
typed(self.untyped.minus(other.untyped))
- /** Subtraction. Subtract the other expression from this expression.
- * {{{
- * // The following selects the difference between people's height and their weight.
- * people.select( people.col('height) - people.col('weight) )
- * }}}
- *
- * apache/spark
- */
- def -[TT, W](other: ThisType[TT, U])(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W]): ThisType[W, U] =
+ /**
+ * Subtraction. Subtract the other expression from this expression.
+ * {{{
+ * // The following selects the difference between people's height and their weight.
+ * people.select( people.col('height) - people.col('weight) )
+ * }}}
+ *
+ * apache/spark
+ */
+ def -[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ n: CatalystNumeric[U],
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, U] =
minus(other)
- /** Subtraction. Subtract the other expression from this expression.
- * {{{
- * // The following selects the difference between people's height and their weight.
- * people.select( people('height) - 1 )
- * }}}
- *
- * @param u a constant of the same type
- * apache/spark
- */
- def -(u: U)(implicit n: CatalystNumeric[U]): ThisType[T, U] =
+ /**
+ * Subtraction. Subtract a constant from this expression.
+ * {{{
+ * // The following subtracts 1 from a person's height.
+ * people.select( people('height) - 1 )
+ * }}}
+ *
+ * @param u a constant of the same type
+ * apache/spark
+ */
+ def -(
+ u: U
+ )(implicit
+ n: CatalystNumeric[U]
+ ): ThisType[T, U] =
typed(self.untyped.minus(u))
- /** Multiplication of this expression and another expression.
- * {{{
- * // The following multiplies a person's height by their weight.
- * people.select( people.col('height) multiply people.col('weight) )
- * }}}
- *
- * apache/spark
- */
- def multiply[TT, W]
- (other: ThisType[TT, U])
- (implicit
+ /**
+ * Multiplication of this expression and another expression.
+ * {{{
+ * // The following multiplies a person's height by their weight.
+ * people.select( people.col('height) multiply people.col('weight) )
+ * }}}
+ *
+ * apache/spark
+ */
+ def multiply[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
n: CatalystNumeric[U],
w: With.Aux[T, TT, W],
t: ClassTag[U]
): ThisType[W, U] = typed {
- if (t.runtimeClass == BigDecimal(0).getClass) {
- // That's apparently the only way to get sound multiplication.
- // See https://issues.apache.org/jira/browse/SPARK-22036
- val dt = DecimalType(20, 14)
- self.untyped.cast(dt).multiply(other.untyped.cast(dt))
- } else {
- self.untyped.multiply(other.untyped)
- }
+ if (t.runtimeClass == BigDecimal(0).getClass) {
+ // That's apparently the only way to get sound multiplication.
+ // See https://issues.apache.org/jira/browse/SPARK-22036
+ val dt = DecimalType(20, 14)
+ self.untyped.cast(dt).multiply(other.untyped.cast(dt))
+ } else {
+ self.untyped.multiply(other.untyped)
}
+ }
- /** Multiplication of this expression and another expression.
- * {{{
- * // The following multiplies a person's height by their weight.
- * people.select( people.col('height) * people.col('weight) )
- * }}}
- *
- * apache/spark
- */
- def *[TT, W](other: ThisType[TT, U])(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W], t: ClassTag[U]): ThisType[W, U] =
+ /**
+ * Multiplication of this expression and another expression.
+ * {{{
+ * // The following multiplies a person's height by their weight.
+ * people.select( people.col('height) * people.col('weight) )
+ * }}}
+ *
+ * apache/spark
+ */
+ def *[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ n: CatalystNumeric[U],
+ w: With.Aux[T, TT, W],
+ t: ClassTag[U]
+ ): ThisType[W, U] =
multiply(other)
- /** Multiplication of this expression a constant.
- * {{{
- * // The following multiplies a person's height by their weight.
- * people.select( people.col('height) * people.col('weight) )
- * }}}
- *
- * apache/spark
- */
- def *(u: U)(implicit n: CatalystNumeric[U]): ThisType[T, U] =
+ /**
+ * Multiplication of this expression and a constant.
+ * {{{
+ * // The following multiplies a person's height by 2.
+ * people.select( people.col('height) * 2 )
+ * }}}
+ *
+ * apache/spark
+ */
+ def *(
+ u: U
+ )(implicit
+ n: CatalystNumeric[U]
+ ): ThisType[T, U] =
typed(self.untyped.multiply(u))
- /** Modulo (a.k.a. remainder) expression.
- *
- * apache/spark
- */
- def mod[Out: TypedEncoder, TT, W](other: ThisType[TT, U])(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W]): ThisType[W, Out] =
+ /**
+ * Modulo (a.k.a. remainder) expression.
+ *
+ * apache/spark
+ */
+ def mod[Out: TypedEncoder, TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ n: CatalystNumeric[U],
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, Out] =
typed(self.untyped.mod(other.untyped))
- /** Modulo (a.k.a. remainder) expression.
- *
- * apache/spark
- */
- def %[TT, W](other: ThisType[TT, U])(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W]): ThisType[W, U] =
+ /**
+ * Modulo (a.k.a. remainder) expression.
+ *
+ * apache/spark
+ */
+ def %[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ n: CatalystNumeric[U],
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, U] =
mod(other)
- /** Modulo (a.k.a. remainder) expression.
- *
- * apache/spark
- */
- def %(u: U)(implicit n: CatalystNumeric[U]): ThisType[T, U] =
+ /**
+ * Modulo (a.k.a. remainder) expression.
+ *
+ * apache/spark
+ */
+ def %(
+ u: U
+ )(implicit
+ n: CatalystNumeric[U]
+ ): ThisType[T, U] =
typed(self.untyped.mod(u))
- /** Division this expression by another expression.
- * {{{
- * // The following divides a person's height by their weight.
- * people.select( people('height) / people('weight) )
- * }}}
- *
- * @param other another column of the same type
- * apache/spark
- */
- def divide[Out: TypedEncoder, TT, W](other: ThisType[TT, U])(implicit n: CatalystDivisible[U, Out], w: With.Aux[T, TT, W]): ThisType[W, Out] =
+ /**
+ * Division of this expression by another expression.
+ * {{{
+ * // The following divides a person's height by their weight.
+ * people.select( people('height) / people('weight) )
+ * }}}
+ *
+ * @param other another column of the same type
+ * apache/spark
+ */
+ def divide[Out: TypedEncoder, TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ n: CatalystDivisible[U, Out],
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, Out] =
typed(self.untyped.divide(other.untyped))
- /** Division this expression by another expression.
- * {{{
- * // The following divides a person's height by their weight.
- * people.select( people('height) / people('weight) )
- * }}}
- *
- * @param other another column of the same type
- * apache/spark
- */
- def /[Out, TT, W](other: ThisType[TT, U])(implicit n: CatalystDivisible[U, Out], e: TypedEncoder[Out], w: With.Aux[T, TT, W]): ThisType[W, Out] =
+ /**
+ * Division of this expression by another expression.
+ * {{{
+ * // The following divides a person's height by their weight.
+ * people.select( people('height) / people('weight) )
+ * }}}
+ *
+ * @param other another column of the same type
+ * apache/spark
+ */
+ def /[Out, TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ n: CatalystDivisible[U, Out],
+ e: TypedEncoder[Out],
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, Out] =
divide(other)
- /** Division this expression by another expression.
- * {{{
- * // The following divides a person's height by their weight.
- * people.select( people('height) / 2 )
- * }}}
- *
- * @param u a constant of the same type
- * apache/spark
- */
- def /(u: U)(implicit n: CatalystNumeric[U]): ThisType[T, Double] =
+ /**
+ * Division of this expression by a constant.
+ * {{{
+ * // The following divides a person's height by 2.
+ * people.select( people('height) / 2 )
+ * }}}
+ *
+ * @param u a constant of the same type
+ * apache/spark
+ */
+ def /(
+ u: U
+ )(implicit
+ n: CatalystNumeric[U]
+ ): ThisType[T, Double] =
typed(self.untyped.divide(u))
- /** Returns a descending ordering used in sorting
- *
- * apache/spark
- */
- def desc(implicit catalystOrdered: CatalystOrdered[U]): SortedTypedColumn[T, U] =
+ /**
+ * Returns a descending ordering used in sorting
+ *
+ * apache/spark
+ */
+ def desc(
+ implicit
+ catalystOrdered: CatalystOrdered[U]
+ ): SortedTypedColumn[T, U] =
new SortedTypedColumn[T, U](untyped.desc)
- /** Returns an ascending ordering used in sorting
- *
- * apache/spark
- */
- def asc(implicit catalystOrdered: CatalystOrdered[U]): SortedTypedColumn[T, U] =
+ /**
+ * Returns an ascending ordering used in sorting
+ *
+ * apache/spark
+ */
+ def asc(
+ implicit
+ catalystOrdered: CatalystOrdered[U]
+ ): SortedTypedColumn[T, U] =
new SortedTypedColumn[T, U](untyped.asc)
- /** Bitwise AND this expression and another expression.
- * {{{
- * df.select(df.col('colA) bitwiseAND (df.col('colB)))
- * }}}
- *
- * @param u a constant of the same type
- * apache/spark
- */
- def bitwiseAND(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] =
+ /**
+ * Bitwise AND this expression and another expression.
+ * {{{
+ * df.select(df.col('colA) bitwiseAND (df.col('colB)))
+ * }}}
+ *
+ * @param u a constant of the same type
+ * apache/spark
+ */
+ def bitwiseAND(
+ u: U
+ )(implicit
+ n: CatalystBitwise[U]
+ ): ThisType[T, U] =
typed(self.untyped.bitwiseAND(u))
- /** Bitwise AND this expression and another expression.
- * {{{
- * df.select(df.col('colA) bitwiseAND (df.col('colB)))
- * }}}
- *
- * @param u a constant of the same type
- * apache/spark
- */
- def bitwiseAND[TT, W](other: ThisType[TT, U])(implicit n: CatalystBitwise[U], w: With.Aux[T, TT, W]): ThisType[W, U] =
+ /**
+ * Bitwise AND this expression and another expression.
+ * {{{
+ * df.select(df.col('colA) bitwiseAND (df.col('colB)))
+ * }}}
+ *
+ * @param other another column of the same type
+ * apache/spark
+ */
+ def bitwiseAND[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ n: CatalystBitwise[U],
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, U] =
typed(self.untyped.bitwiseAND(other.untyped))
- /** Bitwise AND this expression and another expression (of same type).
- * {{{
- * df.select(df.col('colA).cast[Int] & -1)
- * }}}
- *
- * @param u a constant of the same type
- * apache/spark
- */
- def &(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] =
+ /**
+ * Bitwise AND this expression and another expression (of same type).
+ * {{{
+ * df.select(df.col('colA).cast[Int] & -1)
+ * }}}
+ *
+ * @param u a constant of the same type
+ * apache/spark
+ */
+ def &(
+ u: U
+ )(implicit
+ n: CatalystBitwise[U]
+ ): ThisType[T, U] =
bitwiseAND(u)
- /** Bitwise AND this expression and another expression.
- * {{{
- * df.select(df.col('colA) & (df.col('colB)))
- * }}}
- *
- * @param other a constant of the same type
- * apache/spark
- */
- def &[TT, W](other: ThisType[TT, U])(implicit n: CatalystBitwise[U], w: With.Aux[T, TT, W]): ThisType[W, U] =
+ /**
+ * Bitwise AND this expression and another expression.
+ * {{{
+ * df.select(df.col('colA) & (df.col('colB)))
+ * }}}
+ *
+ * @param other another column of the same type
+ * apache/spark
+ */
+ def &[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ n: CatalystBitwise[U],
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, U] =
bitwiseAND(other)
- /** Bitwise OR this expression and another expression.
- * {{{
- * df.select(df.col('colA) bitwiseOR (df.col('colB)))
- * }}}
- *
- * @param u a constant of the same type
- * apache/spark
- */
- def bitwiseOR(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] =
+ /**
+ * Bitwise OR this expression and another expression.
+ * {{{
+ * df.select(df.col('colA) bitwiseOR (df.col('colB)))
+ * }}}
+ *
+ * @param u a constant of the same type
+ * apache/spark
+ */
+ def bitwiseOR(
+ u: U
+ )(implicit
+ n: CatalystBitwise[U]
+ ): ThisType[T, U] =
typed(self.untyped.bitwiseOR(u))
- /** Bitwise OR this expression and another expression.
- * {{{
- * df.select(df.col('colA) bitwiseOR (df.col('colB)))
- * }}}
- *
- * @param other a constant of the same type
- * apache/spark
- */
- def bitwiseOR[TT, W](other: ThisType[TT, U])(implicit n: CatalystBitwise[U], w: With.Aux[T, TT, W]): ThisType[W, U] =
+ /**
+ * Bitwise OR this expression and another expression.
+ * {{{
+ * df.select(df.col('colA) bitwiseOR (df.col('colB)))
+ * }}}
+ *
+ * @param other another column of the same type
+ * apache/spark
+ */
+ def bitwiseOR[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ n: CatalystBitwise[U],
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, U] =
typed(self.untyped.bitwiseOR(other.untyped))
- /** Bitwise OR this expression and another expression (of same type).
- * {{{
- * df.select(df.col('colA).cast[Long] | 1L)
- * }}}
- *
- * @param u a constant of the same type
- * apache/spark
- */
- def |(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] =
+ /**
+ * Bitwise OR this expression and another expression (of same type).
+ * {{{
+ * df.select(df.col('colA).cast[Long] | 1L)
+ * }}}
+ *
+ * @param u a constant of the same type
+ * apache/spark
+ */
+ def |(
+ u: U
+ )(implicit
+ n: CatalystBitwise[U]
+ ): ThisType[T, U] =
bitwiseOR(u)
- /** Bitwise OR this expression and another expression.
- * {{{
- * df.select(df.col('colA) | (df.col('colB)))
- * }}}
- *
- * @param other a constant of the same type
- * apache/spark
- */
- def |[TT, W](other: ThisType[TT, U])(implicit n: CatalystBitwise[U], w: With.Aux[T, TT, W]): ThisType[W, U] =
+ /**
+ * Bitwise OR this expression and another expression.
+ * {{{
+ * df.select(df.col('colA) | (df.col('colB)))
+ * }}}
+ *
+ * @param other another column of the same type
+ * apache/spark
+ */
+ def |[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ n: CatalystBitwise[U],
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, U] =
bitwiseOR(other)
- /** Bitwise XOR this expression and another expression.
- * {{{
- * df.select(df.col('colA) bitwiseXOR (df.col('colB)))
- * }}}
- *
- * @param u a constant of the same type
- * apache/spark
- */
- def bitwiseXOR(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] =
+ /**
+ * Bitwise XOR this expression and another expression.
+ * {{{
+ * df.select(df.col('colA) bitwiseXOR (df.col('colB)))
+ * }}}
+ *
+ * @param u a constant of the same type
+ * apache/spark
+ */
+ def bitwiseXOR(
+ u: U
+ )(implicit
+ n: CatalystBitwise[U]
+ ): ThisType[T, U] =
typed(self.untyped.bitwiseXOR(u))
- /** Bitwise XOR this expression and another expression.
- * {{{
- * df.select(df.col('colA) bitwiseXOR (df.col('colB)))
- * }}}
- *
- * @param other a constant of the same type
- * apache/spark
- */
- def bitwiseXOR[TT, W](other: ThisType[TT, U])(implicit n: CatalystBitwise[U], w: With.Aux[T, TT, W]): ThisType[W, U] =
+ /**
+ * Bitwise XOR this expression and another expression.
+ * {{{
+ * df.select(df.col('colA) bitwiseXOR (df.col('colB)))
+ * }}}
+ *
+ * @param other another column of the same type
+ * apache/spark
+ */
+ def bitwiseXOR[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ n: CatalystBitwise[U],
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, U] =
typed(self.untyped.bitwiseXOR(other.untyped))
- /** Bitwise XOR this expression and another expression (of same type).
- * {{{
- * df.select(df.col('colA).cast[Long] ^ 1L)
- * }}}
- *
- * @param u a constant of the same type
- * apache/spark
- */
- def ^(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] =
+ /**
+ * Bitwise XOR this expression and another expression (of same type).
+ * {{{
+ * df.select(df.col('colA).cast[Long] ^ 1L)
+ * }}}
+ *
+ * @param u a constant of the same type
+ * apache/spark
+ */
+ def ^(
+ u: U
+ )(implicit
+ n: CatalystBitwise[U]
+ ): ThisType[T, U] =
bitwiseXOR(u)
- /** Bitwise XOR this expression and another expression.
- * {{{
- * df.select(df.col('colA) ^ (df.col('colB)))
- * }}}
- *
- * @param other a constant of the same type
- * apache/spark
- */
- def ^[TT, W](other: ThisType[TT, U])(implicit n: CatalystBitwise[U], w: With.Aux[T, TT, W]): ThisType[W, U] =
+ /**
+ * Bitwise XOR this expression and another expression.
+ * {{{
+ * df.select(df.col('colA) ^ (df.col('colB)))
+ * }}}
+ *
+ * @param other another column of the same type
+ * apache/spark
+ */
+ def ^[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ n: CatalystBitwise[U],
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, U] =
bitwiseXOR(other)
- /** Casts the column to a different type.
- * {{{
- * df.select(df('a).cast[Int])
- * }}}
- */
- def cast[A: TypedEncoder](implicit c: CatalystCast[U, A]): ThisType[T, A] =
+ /**
+ * Casts the column to a different type.
+ * {{{
+ * df.select(df('a).cast[Int])
+ * }}}
+ */
+ def cast[A: TypedEncoder](
+ implicit
+ c: CatalystCast[U, A]
+ ): ThisType[T, A] =
typed(self.untyped.cast(TypedEncoder[A].catalystRepr))
/**
- * An expression that returns a substring
- * {{{
- * df.select(df('a).substr(0, 5))
- * }}}
- *
- * @param startPos starting position
- * @param len length of the substring
- */
- def substr(startPos: Int, len: Int)(implicit ev: U =:= String): ThisType[T, String] =
+ * An expression that returns a substring
+ * {{{
+ * df.select(df('a).substr(0, 5))
+ * }}}
+ *
+ * @param startPos starting position
+ * @param len length of the substring
+ */
+ def substr(
+ startPos: Int,
+ len: Int
+ )(implicit
+ ev: U =:= String
+ ): ThisType[T, String] =
typed(self.untyped.substr(startPos, len))
/**
- * An expression that returns a substring
- * {{{
- * df.select(df('a).substr(df('b), df('c)))
- * }}}
- *
- * @param startPos expression for the starting position
- * @param len expression for the length of the substring
- */
- def substr[TT1, TT2, W1, W2](startPos: ThisType[TT1, Int], len: ThisType[TT2, Int])
- (implicit
- ev: U =:= String,
- w1: With.Aux[T, TT1, W1],
- w2: With.Aux[W1, TT2, W2]): ThisType[W2, String] =
+ * An expression that returns a substring
+ * {{{
+ * df.select(df('a).substr(df('b), df('c)))
+ * }}}
+ *
+ * @param startPos expression for the starting position
+ * @param len expression for the length of the substring
+ */
+ def substr[TT1, TT2, W1, W2](
+ startPos: ThisType[TT1, Int],
+ len: ThisType[TT2, Int]
+ )(implicit
+ ev: U =:= String,
+ w1: With.Aux[T, TT1, W1],
+ w2: With.Aux[W1, TT2, W2]
+ ): ThisType[W2, String] =
typed(self.untyped.substr(startPos.untyped, len.untyped))
- /** SQL like expression. Returns a boolean column based on a SQL LIKE match.
- * {{{
- * val ds = TypedDataset.create(X2("foo", "bar") :: Nil)
- * // true
- * ds.select(ds('a).like("foo"))
- *
- * // Selected column has value "bar"
- * ds.select(when(ds('a).like("f"), ds('a)).otherwise(ds('b))
- * }}}
- * apache/spark
- */
- def like(literal: String)(implicit ev: U =:= String): ThisType[T, Boolean] =
+ /**
+ * SQL like expression. Returns a boolean column based on a SQL LIKE match.
+ * {{{
+ * val ds = TypedDataset.create(X2("foo", "bar") :: Nil)
+ * // true
+ * ds.select(ds('a).like("foo"))
+ *
+ * // Selected column has value "bar"
+ * ds.select(when(ds('a).like("f"), ds('a)).otherwise(ds('b)))
+ * }}}
+ * apache/spark
+ */
+ def like(
+ literal: String
+ )(implicit
+ ev: U =:= String
+ ): ThisType[T, Boolean] =
typed(self.untyped.like(literal))
- /** SQL RLIKE expression (LIKE with Regex). Returns a boolean column based on a regex match.
- * {{{
- * val ds = TypedDataset.create(X1("foo") :: Nil)
- * // true
- * ds.select(ds('a).rlike("foo"))
- *
- * // true
- * ds.select(ds('a).rlike(".*))
- * }}}
- * apache/spark
- */
- def rlike(literal: String)(implicit ev: U =:= String): ThisType[T, Boolean] =
+ /**
+ * SQL RLIKE expression (LIKE with Regex). Returns a boolean column based on a regex match.
+ * {{{
+ * val ds = TypedDataset.create(X1("foo") :: Nil)
+ * // true
+ * ds.select(ds('a).rlike("foo"))
+ *
+ * // true
+ * ds.select(ds('a).rlike(".*"))
+ * }}}
+ * apache/spark
+ */
+ def rlike(
+ literal: String
+ )(implicit
+ ev: U =:= String
+ ): ThisType[T, Boolean] =
typed(self.untyped.rlike(literal))
- /** String contains another string literal.
- * {{{
- * df.filter ( df.col('a).contains("foo") )
- * }}}
- *
- * @param other a string that is being tested against.
- * apache/spark
- */
- def contains(other: String)(implicit ev: U =:= String): ThisType[T, Boolean] =
+ /**
+ * String contains another string literal.
+ * {{{
+ * df.filter ( df.col('a).contains("foo") )
+ * }}}
+ *
+ * @param other a string that is being tested against.
+ * apache/spark
+ */
+ def contains(
+ other: String
+ )(implicit
+ ev: U =:= String
+ ): ThisType[T, Boolean] =
typed(self.untyped.contains(other))
- /** String contains.
- * {{{
- * df.filter ( df.col('a).contains(df.col('b) )
- * }}}
- *
- * @param other a column which values is used as a string that is being tested against.
- * apache/spark
- */
- def contains[TT, W](other: ThisType[TT, U])(implicit ev: U =:= String, w: With.Aux[T, TT, W]): ThisType[W, Boolean] =
+ /**
+ * String contains.
+ * {{{
+ * df.filter ( df.col('a).contains(df.col('b)) )
+ * }}}
+ *
+ * @param other a column whose values are used as the string being tested against.
+ * apache/spark
+ */
+ def contains[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ ev: U =:= String,
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, Boolean] =
typed(self.untyped.contains(other.untyped))
- /** String starts with another string literal.
- * {{{
- * df.filter ( df.col('a).startsWith("foo")
- * }}}
- *
- * @param other a prefix that is being tested against.
- * apache/spark
- */
- def startsWith(other: String)(implicit ev: U =:= String): ThisType[T, Boolean] =
+ /**
+ * String starts with another string literal.
+ * {{{
+ * df.filter ( df.col('a).startsWith("foo") )
+ * }}}
+ *
+ * @param other a prefix that is being tested against.
+ * apache/spark
+ */
+ def startsWith(
+ other: String
+ )(implicit
+ ev: U =:= String
+ ): ThisType[T, Boolean] =
typed(self.untyped.startsWith(other))
- /** String starts with.
- * {{{
- * df.filter ( df.col('a).startsWith(df.col('b))
- * }}}
- *
- * @param other a column which values is used as a prefix that is being tested against.
- * apache/spark
- */
- def startsWith[TT, W](other: ThisType[TT, U])(implicit ev: U =:= String, w: With.Aux[T, TT, W]): ThisType[W, Boolean] =
+ /**
+ * String starts with.
+ * {{{
+ * df.filter ( df.col('a).startsWith(df.col('b)) )
+ * }}}
+ *
+ * @param other a column whose values are used as the prefix being tested against.
+ * apache/spark
+ */
+ def startsWith[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ ev: U =:= String,
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, Boolean] =
typed(self.untyped.startsWith(other.untyped))
- /** String ends with another string literal.
- * {{{
- * df.filter ( df.col('a).endsWith("foo")
- * }}}
- *
- * @param other a suffix that is being tested against.
- * apache/spark
- */
- def endsWith(other: String)(implicit ev: U =:= String): ThisType[T, Boolean] =
+ /**
+ * String ends with another string literal.
+ * {{{
+ * df.filter ( df.col('a).endsWith("foo") )
+ * }}}
+ *
+ * @param other a suffix that is being tested against.
+ * apache/spark
+ */
+ def endsWith(
+ other: String
+ )(implicit
+ ev: U =:= String
+ ): ThisType[T, Boolean] =
typed(self.untyped.endsWith(other))
- /** String ends with.
- * {{{
- * df.filter ( df.col('a).endsWith(df.col('b))
- * }}}
- *
- * @param other a column which values is used as a suffix that is being tested against.
- * apache/spark
- */
- def endsWith[TT, W](other: ThisType[TT, U])(implicit ev: U =:= String, w: With.Aux[T, TT, W]): ThisType[W, Boolean] =
+ /**
+ * String ends with.
+ * {{{
+ * df.filter ( df.col('a).endsWith(df.col('b)) )
+ * }}}
+ *
+ * @param other a column which values is used as a suffix that is being tested against.
+ * apache/spark
+ */
+ def endsWith[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ ev: U =:= String,
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, Boolean] =
typed(self.untyped.endsWith(other.untyped))
- /** Boolean AND.
- * {{{
- * df.filter ( (df.col('a) === 1).and(df.col('b) > 5) )
- * }}}
- */
- def and[TT, W](other: ThisType[TT, Boolean])(implicit w: With.Aux[T, TT, W]): ThisType[W, Boolean] =
+ /**
+ * Boolean AND.
+ * {{{
+ * df.filter ( (df.col('a) === 1).and(df.col('b) > 5) )
+ * }}}
+ */
+ def and[TT, W](
+ other: ThisType[TT, Boolean]
+ )(implicit
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, Boolean] =
typed(self.untyped.and(other.untyped))
- /** Boolean AND.
- * {{{
- * df.filter ( df.col('a) === 1 && df.col('b) > 5)
- * }}}
- */
- def && [TT, W](other: ThisType[TT, Boolean])(implicit w: With.Aux[T, TT, W]): ThisType[W, Boolean] =
+ /**
+ * Boolean AND.
+ * {{{
+ * df.filter ( df.col('a) === 1 && df.col('b) > 5)
+ * }}}
+ */
+ def &&[TT, W](
+ other: ThisType[TT, Boolean]
+ )(implicit
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, Boolean] =
and(other)
- /** Boolean OR.
- * {{{
- * df.filter ( (df.col('a) === 1).or(df.col('b) > 5) )
- * }}}
- */
- def or[TT, W](other: ThisType[TT, Boolean])(implicit w: With.Aux[T, TT, W]): ThisType[W, Boolean] =
+ /**
+ * Boolean OR.
+ * {{{
+ * df.filter ( (df.col('a) === 1).or(df.col('b) > 5) )
+ * }}}
+ */
+ def or[TT, W](
+ other: ThisType[TT, Boolean]
+ )(implicit
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, Boolean] =
typed(self.untyped.or(other.untyped))
- /** Boolean OR.
- * {{{
- * df.filter ( df.col('a) === 1 || df.col('b) > 5)
- * }}}
- */
- def || [TT, W](other: ThisType[TT, Boolean])(implicit w: With.Aux[T, TT, W]): ThisType[W, Boolean] =
+ /**
+ * Boolean OR.
+ * {{{
+ * df.filter ( df.col('a) === 1 || df.col('b) > 5)
+ * }}}
+ */
+ def ||[TT, W](
+ other: ThisType[TT, Boolean]
+ )(implicit
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, Boolean] =
or(other)
- /** Less than.
- *
- * {{{
- * // The following selects people younger than the maxAge column.
- * df.select(df('age) < df('maxAge) )
- * }}}
- *
- * @param other another column of the same type
- * apache/spark
- */
- def <[TT, W](other: ThisType[TT, U])(implicit i0: CatalystOrdered[U], w: With.Aux[T, TT, W]): ThisType[W, Boolean] =
+ /**
+ * Less than.
+ *
+ * {{{
+ * // The following selects people younger than the maxAge column.
+ * df.select(df('age) < df('maxAge) )
+ * }}}
+ *
+ * @param other another column of the same type
+ * apache/spark
+ */
+ def <[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ i0: CatalystOrdered[U],
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, Boolean] =
typed(self.untyped < other.untyped)
- /** Less than or equal to.
- *
- * {{{
- * // The following selects people younger or equal than the maxAge column.
- * df.select(df('age) <= df('maxAge)
- * }}}
- *
- * @param other another column of the same type
- * apache/spark
- */
- def <=[TT, W](other: ThisType[TT, U])(implicit i0: CatalystOrdered[U], w: With.Aux[T, TT, W]): ThisType[W, Boolean] =
+ /**
+ * Less than or equal to.
+ *
+ * {{{
+ * // The following selects people younger than or equal to the maxAge column.
+ * df.select( df('age) <= df('maxAge) )
+ * }}}
+ *
+ * @param other another column of the same type
+ * apache/spark
+ */
+ def <=[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ i0: CatalystOrdered[U],
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, Boolean] =
typed(self.untyped <= other.untyped)
- /** Greater than.
- * {{{
- * // The following selects people older than the maxAge column.
- * df.select( df('age) > df('maxAge) )
- * }}}
- *
- * @param other another column of the same type
- * apache/spark
- */
- def >[TT, W](other: ThisType[TT, U])(implicit i0: CatalystOrdered[U], w: With.Aux[T, TT, W]): ThisType[W, Boolean] =
+ /**
+ * Greater than.
+ * {{{
+ * // The following selects people older than the maxAge column.
+ * df.select( df('age) > df('maxAge) )
+ * }}}
+ *
+ * @param other another column of the same type
+ * apache/spark
+ */
+ def >[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ i0: CatalystOrdered[U],
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, Boolean] =
typed(self.untyped > other.untyped)
- /** Greater than or equal.
- * {{{
- * // The following selects people older or equal than the maxAge column.
- * df.select( df('age) >= df('maxAge) )
- * }}}
- *
- * @param other another column of the same type
- * apache/spark
- */
- def >=[TT, W](other: ThisType[TT, U])(implicit i0: CatalystOrdered[U], w: With.Aux[T, TT, W]): ThisType[W, Boolean] =
+ /**
+ * Greater than or equal.
+ * {{{
+ * // The following selects people older than or equal to the maxAge column.
+ * df.select( df('age) >= df('maxAge) )
+ * }}}
+ *
+ * @param other another column of the same type
+ * apache/spark
+ */
+ def >=[TT, W](
+ other: ThisType[TT, U]
+ )(implicit
+ i0: CatalystOrdered[U],
+ w: With.Aux[T, TT, W]
+ ): ThisType[W, Boolean] =
typed(self.untyped >= other.untyped)
- /** Less than.
- * {{{
- * // The following selects people younger than 21.
- * df.select( df('age) < 21 )
- * }}}
- *
- * @param u a constant of the same type
- * apache/spark
- */
- def <(u: U)(implicit i0: CatalystOrdered[U]): ThisType[T, Boolean] =
+ /**
+ * Less than.
+ * {{{
+ * // The following selects people younger than 21.
+ * df.select( df('age) < 21 )
+ * }}}
+ *
+ * @param u a constant of the same type
+ * apache/spark
+ */
+ def <(
+ u: U
+ )(implicit
+ i0: CatalystOrdered[U]
+ ): ThisType[T, Boolean] =
typed(self.untyped < lit(u)(self.uencoder).untyped)
- /** Less than or equal to.
- * {{{
- * // The following selects people younger than 22.
- * df.select( df('age) <= 2 )
- * }}}
- *
- * @param u a constant of the same type
- * apache/spark
- */
- def <=(u: U)(implicit i0: CatalystOrdered[U]): ThisType[T, Boolean] =
+ /**
+ * Less than or equal to.
+ * {{{
+ * // The following selects people younger than 22.
+ * df.select( df('age) <= 21 )
+ * }}}
+ *
+ * @param u a constant of the same type
+ * apache/spark
+ */
+ def <=(
+ u: U
+ )(implicit
+ i0: CatalystOrdered[U]
+ ): ThisType[T, Boolean] =
typed(self.untyped <= lit(u)(self.uencoder).untyped)
- /** Greater than.
- * {{{
- * // The following selects people older than 21.
- * df.select( df('age) > 21 )
- * }}}
- *
- * @param u another column of the same type
- * apache/spark
- */
- def >(u: U)(implicit i0: CatalystOrdered[U]): ThisType[T, Boolean] =
+ /**
+ * Greater than.
+ * {{{
+ * // The following selects people older than 21.
+ * df.select( df('age) > 21 )
+ * }}}
+ *
+ * @param u a constant of the same type
+ * apache/spark
+ */
+ def >(
+ u: U
+ )(implicit
+ i0: CatalystOrdered[U]
+ ): ThisType[T, Boolean] =
typed(self.untyped > lit(u)(self.uencoder).untyped)
- /** Greater than or equal.
- * {{{
- * // The following selects people older than 20.
- * df.select( df('age) >= 21 )
- * }}}
- *
- * @param u another column of the same type
- * apache/spark
- */
- def >=(u: U)(implicit i0: CatalystOrdered[U]): ThisType[T, Boolean] =
+ /**
+ * Greater than or equal.
+ * {{{
+ * // The following selects people older than 20.
+ * df.select( df('age) >= 21 )
+ * }}}
+ *
+ * @param u a constant of the same type
+ * apache/spark
+ */
+ def >=(
+ u: U
+ )(implicit
+ i0: CatalystOrdered[U]
+ ): ThisType[T, Boolean] =
typed(self.untyped >= lit(u)(self.uencoder).untyped)
/**
- * Returns true if the value of this column is contained in of the arguments.
- * {{{
- * // The following selects people with age 15, 20, or 30.
- * df.select( df('age).isin(15, 20, 30) )
- * }}}
- *
- * @param values are constants of the same type
- * apache/spark
- */
- def isin(values: U*)(implicit e: CatalystIsin[U]): ThisType[T, Boolean] =
- typed(self.untyped.isin(values:_*))
-
- /**
- * True if the current column is between the lower bound and upper bound, inclusive.
- *
- * @param lowerBound a constant of the same type
- * @param upperBound a constant of the same type
- * apache/spark
- */
- def between(lowerBound: U, upperBound: U)(implicit i0: CatalystOrdered[U]): ThisType[T, Boolean] =
- typed(self.untyped.between(lit(lowerBound)(self.uencoder).untyped, lit(upperBound)(self.uencoder).untyped))
-
- /**
- * True if the current column is between the lower bound and upper bound, inclusive.
- *
- * @param lowerBound another column of the same type
- * @param upperBound another column of the same type
- * apache/spark
- */
- def between[TT1, TT2, W1, W2](lowerBound: ThisType[TT1, U], upperBound: ThisType[TT2, U])
- (implicit
+ * Returns true if the value of this column is contained in the arguments.
+ * {{{
+ * // The following selects people with age 15, 20, or 30.
+ * df.select( df('age).isin(15, 20, 30) )
+ * }}}
+ *
+ * @param values are constants of the same type
+ * apache/spark
+ */
+ def isin(
+ values: U*
+ )(implicit
+ e: CatalystIsin[U]
+ ): ThisType[T, Boolean] =
+ typed(self.untyped.isin(values: _*))
+
+ /**
+ * True if the current column is between the lower bound and upper bound, inclusive.
+ *
+ * @param lowerBound a constant of the same type
+ * @param upperBound a constant of the same type
+ * apache/spark
+ */
+ def between(
+ lowerBound: U,
+ upperBound: U
+ )(implicit
+ i0: CatalystOrdered[U]
+ ): ThisType[T, Boolean] =
+ typed(
+ self.untyped.between(
+ lit(lowerBound)(self.uencoder).untyped,
+ lit(upperBound)(self.uencoder).untyped
+ )
+ )
+
+ /**
+ * True if the current column is between the lower bound and upper bound, inclusive.
+ *
+ * @param lowerBound another column of the same type
+ * @param upperBound another column of the same type
+ * apache/spark
+ */
+ def between[TT1, TT2, W1, W2](
+ lowerBound: ThisType[TT1, U],
+ upperBound: ThisType[TT2, U]
+ )(implicit
i0: CatalystOrdered[U],
w0: With.Aux[T, TT1, W1],
w1: With.Aux[TT2, W1, W2]
): ThisType[W2, Boolean] =
- typed(self.untyped.between(lowerBound.untyped, upperBound.untyped))
+ typed(self.untyped.between(lowerBound.untyped, upperBound.untyped))
/**
- * Returns a nested column matching the field `symbol`.
- *
- * @param symbol the field symbol
- * @tparam V the type of the nested field
- */
- def field[V](symbol: Witness.Lt[Symbol])(implicit
+ * Returns a nested column matching the field `symbol`.
+ *
+ * @param symbol the field symbol
+ * @tparam V the type of the nested field
+ */
+ def field[V](
+ symbol: Witness.Lt[Symbol]
+ )(implicit
i0: TypedColumn.Exists[U, symbol.T, V],
i1: TypedEncoder[V]
- ): ThisType[T, V] =
+ ): ThisType[T, V] =
typed(self.untyped.getField(symbol.value.name))
}
-
-sealed class SortedTypedColumn[T, U](val expr: Expression)(
- implicit
- val uencoder: TypedEncoder[U]
-) extends UntypedExpression[T] {
-
- def this(column: Column)(implicit e: TypedEncoder[U]) = {
+sealed class SortedTypedColumn[T, U](
+ val expr: Expression
+ )(implicit
+ val uencoder: TypedEncoder[U])
+ extends UntypedExpression[T] {
+
+ def this(
+ column: Column
+ )(implicit
+ e: TypedEncoder[U]
+ ) = {
this(FramelessInternals.expr(column))
}
@@ -894,16 +1276,24 @@ sealed class SortedTypedColumn[T, U](val expr: Expression)(
}
object SortedTypedColumn {
- implicit def defaultAscending[T, U : CatalystOrdered](typedColumn: TypedColumn[T, U]): SortedTypedColumn[T, U] =
+
+ implicit def defaultAscending[T, U: CatalystOrdered](
+ typedColumn: TypedColumn[T, U]
+ ): SortedTypedColumn[T, U] =
new SortedTypedColumn[T, U](typedColumn.untyped.asc)(typedColumn.uencoder)
- object defaultAscendingPoly extends Poly1 {
- implicit def caseTypedColumn[T, U : CatalystOrdered] = at[TypedColumn[T, U]](c => defaultAscending(c))
- implicit def caseTypeSortedColumn[T, U] = at[SortedTypedColumn[T, U]](identity)
- }
+ object defaultAscendingPoly extends Poly1 {
+
+ implicit def caseTypedColumn[T, U: CatalystOrdered] =
+ at[TypedColumn[T, U]](c => defaultAscending(c))
+
+ implicit def caseTypeSortedColumn[T, U] =
+ at[SortedTypedColumn[T, U]](identity)
+ }
}
object TypedColumn {
+
/** Evidence that type `T` has column `K` with type `V`. */
@implicitNotFound(msg = "No column ${K} of type ${V} in ${T}")
trait Exists[T, K, V]
@@ -912,37 +1302,46 @@ object TypedColumn {
trait ExistsMany[T, K <: HList, V]
object ExistsMany {
- implicit def deriveCons[T, KH, KT <: HList, V0, V1]
- (implicit
+
+ implicit def deriveCons[T, KH, KT <: HList, V0, V1](
+ implicit
head: Exists[T, KH, V0],
tail: ExistsMany[V0, KT, V1]
): ExistsMany[T, KH :: KT, V1] =
- new ExistsMany[T, KH :: KT, V1] {}
+ new ExistsMany[T, KH :: KT, V1] {}
- implicit def deriveHNil[T, K, V](implicit head: Exists[T, K, V]): ExistsMany[T, K :: HNil, V] =
+ implicit def deriveHNil[T, K, V](
+ implicit
+ head: Exists[T, K, V]
+ ): ExistsMany[T, K :: HNil, V] =
new ExistsMany[T, K :: HNil, V] {}
}
object Exists {
- def apply[T, V](column: Witness)(implicit e: Exists[T, column.T, V]): Exists[T, column.T, V] = e
- implicit def deriveRecord[T, H <: HList, K, V]
- (implicit
+ def apply[T, V](
+ column: Witness
+ )(implicit
+ e: Exists[T, column.T, V]
+ ): Exists[T, column.T, V] = e
+
+ implicit def deriveRecord[T, H <: HList, K, V](
+ implicit
i0: LabelledGeneric.Aux[T, H],
i1: Selector.Aux[H, K, V]
): Exists[T, K, V] = new Exists[T, K, V] {}
}
/**
- * {{{
- * import frameless.TypedColumn
- *
- * case class Foo(id: Int, bar: String)
- *
- * val colbar: TypedColumn[Foo, String] = TypedColumn { foo: Foo => foo.bar }
- * val colid = TypedColumn[Foo, Int](_.id)
- * }}}
- */
+ * {{{
+ * import frameless.TypedColumn
+ *
+ * case class Foo(id: Int, bar: String)
+ *
+ * val colbar: TypedColumn[Foo, String] = TypedColumn { foo: Foo => foo.bar }
+ * val colid = TypedColumn[Foo, Int](_.id)
+ * }}}
+ */
def apply[T, U](x: T => U): TypedColumn[T, U] =
macro TypedColumnMacroImpl.applyImpl[T, U]
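
A minimal usage sketch of the `isin`, `between`, and `field` operators whose signatures are reformatted above; the `Person`/`Address` case classes and the local `SparkSession` are hypothetical, purely for illustration:

    import frameless.TypedDataset
    import org.apache.spark.sql.SparkSession

    case class Address(city: String, zip: String)
    case class Person(name: String, age: Int, address: Address)

    implicit val spark: SparkSession =
      SparkSession.builder().master("local[*]").appName("column-ops").getOrCreate()

    val people: TypedDataset[Person] = TypedDataset.create(Seq(
      Person("Ann", 20, Address("Oslo", "0150")),
      Person("Bob", 35, Address("Bergen", "5003"))
    ))

    // isin: membership test against constant values of the column's type
    val picked = people.filter(people('age).isin(15, 20, 30))

    // between: inclusive range check against constants (or other columns)
    val inRange = people.filter(people('age).between(18, 30))

    // field: statically checked access to a nested field
    val cities = people.select(people('address).field('city))
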
diff --git a/dataset/src/main/scala/frameless/TypedColumnMacroImpl.scala b/dataset/src/main/scala/frameless/TypedColumnMacroImpl.scala
index 62fa2765d..7a1981b92 100644
--- a/dataset/src/main/scala/frameless/TypedColumnMacroImpl.scala
+++ b/dataset/src/main/scala/frameless/TypedColumnMacroImpl.scala
@@ -4,7 +4,10 @@ import scala.reflect.macros.whitebox
private[frameless] object TypedColumnMacroImpl {
- def applyImpl[T: c.WeakTypeTag, U: c.WeakTypeTag](c: whitebox.Context)(x: c.Tree): c.Expr[TypedColumn[T, U]] = {
+ def applyImpl[T: c.WeakTypeTag, U: c.WeakTypeTag](
+ c: whitebox.Context
+ )(x: c.Tree
+ ): c.Expr[TypedColumn[T, U]] = {
import c.universe._
val t = c.weakTypeOf[T]
@@ -13,7 +16,9 @@ private[frameless] object TypedColumnMacroImpl {
def buildExpression(path: List[String]): c.Expr[TypedColumn[T, U]] = {
val columnName = path.mkString(".")
- c.Expr[TypedColumn[T, U]](q"new _root_.frameless.TypedColumn[$t, $u]((org.apache.spark.sql.functions.col($columnName)).expr)")
+ c.Expr[TypedColumn[T, U]](
+ q"new _root_.frameless.TypedColumn[$t, $u]((org.apache.spark.sql.functions.col($columnName)).expr)"
+ )
}
def abort(msg: String) = c.abort(c.enclosingPosition, msg)
@@ -48,34 +53,39 @@ private[frameless] object TypedColumnMacroImpl {
}
x match {
- case fn: Function => fn.body match {
- case select: Select if select.name.isTermName =>
- val expectedRoot: Option[String] = fn.vparams match {
- case List(rt) if rt.rhs == EmptyTree =>
- Option.empty[String]
-
- case List(rt) =>
- Some(rt.toString)
+ case fn: Function =>
+ fn.body match {
+ case select: Select if select.name.isTermName =>
+ val expectedRoot: Option[String] = fn.vparams match {
+ case List(rt) if rt.rhs == EmptyTree =>
+ Option.empty[String]
+
+ case List(rt) =>
+ Some(rt.toString)
+
+ case u =>
+ abort(
+ s"Select expression must have a single parameter: ${u mkString ", "}"
+ )
+ }
- case u =>
- abort(s"Select expression must have a single parameter: ${u mkString ", "}")
- }
+ path(select, List.empty) match {
+ case root :: tail
+ if (expectedRoot.forall(_ == root) && check(t, tail)) => {
+ val colPath = tail.mkString(".")
- path(select, List.empty) match {
- case root :: tail if (
- expectedRoot.forall(_ == root) && check(t, tail)) => {
- val colPath = tail.mkString(".")
+ c.Expr[TypedColumn[T, U]](
+ q"new _root_.frameless.TypedColumn[$t, $u]((org.apache.spark.sql.functions.col($colPath)).expr)"
+ )
+ }
- c.Expr[TypedColumn[T, U]](q"new _root_.frameless.TypedColumn[$t, $u]((org.apache.spark.sql.functions.col($colPath)).expr)")
+ case _ =>
+ abort(s"Invalid select expression: $select")
}
- case _ =>
- abort(s"Invalid select expression: $select")
- }
-
- case t =>
- abort(s"Select expression expected: $t")
- }
+ case t =>
+ abort(s"Select expression expected: $t")
+ }
case _ =>
abort(s"Function expected: $x")
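
For reference, a small sketch of the call sites this macro accepts: it only rewrites plain field-selection lambdas into column paths, and anything else aborts compilation with the messages above. The `Person`/`Address` case classes are hypothetical:

    import frameless.{ TypedColumn, TypedDataset }

    case class Address(city: String)
    case class Person(name: String, address: Address)

    def columns(ds: TypedDataset[Person]): Unit = {
      // simple selection: rewritten to org.apache.spark.sql.functions.col("name")
      val name: TypedColumn[Person, String] = ds.col(_.name)

      // nested selection: rewritten to col("address.city")
      val city: TypedColumn[Person, String] = ds.col(_.address.city)

      // a body that is not a plain field path, such as
      // ds.col((p: Person) => p.name.toUpperCase), is rejected at compile time
      ()
    }
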
diff --git a/dataset/src/main/scala/frameless/TypedDataset.scala b/dataset/src/main/scala/frameless/TypedDataset.scala
index add2170b2..28ad4fa5f 100644
--- a/dataset/src/main/scala/frameless/TypedDataset.scala
+++ b/dataset/src/main/scala/frameless/TypedDataset.scala
@@ -4,36 +4,58 @@ import java.util
import frameless.functions.CatalystExplodableCollection
import frameless.ops._
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Column, DataFrame, Dataset, FramelessInternals, SparkSession}
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal}
-import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint}
+import org.apache.spark.sql.{
+ Column,
+ DataFrame,
+ Dataset,
+ FramelessInternals,
+ SparkSession
+}
+import org.apache.spark.sql.catalyst.expressions.{
+ Attribute,
+ AttributeReference,
+ Literal
+}
+import org.apache.spark.sql.catalyst.plans.logical.{ Join, JoinHint }
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.types.StructType
import shapeless._
import shapeless.labelled.FieldType
-import shapeless.ops.hlist.{Diff, IsHCons, Mapper, Prepend, ToTraversable, Tupler}
-import shapeless.ops.record.{Keys, Modifier, Remover, Values}
+import shapeless.ops.hlist.{
+ Diff,
+ IsHCons,
+ Mapper,
+ Prepend,
+ ToTraversable,
+ Tupler
+}
+import shapeless.ops.record.{ Keys, Modifier, Remover, Values }
import scala.language.experimental.macros
-/** [[TypedDataset]] is a safer interface for working with `Dataset`.
- *
- * NOTE: Prefer `TypedDataset.create` over `new TypedDataset` unless you
- * know what you are doing.
- *
- * Documentation marked "apache/spark" is thanks to apache/spark Contributors
- * at https://github.com/apache/spark, licensed under Apache v2.0 available at
- * http://www.apache.org/licenses/LICENSE-2.0
- */
-class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val encoder: TypedEncoder[T])
+/**
+ * [[TypedDataset]] is a safer interface for working with `Dataset`.
+ *
+ * NOTE: Prefer `TypedDataset.create` over `new TypedDataset` unless you
+ * know what you are doing.
+ *
+ * Documentation marked "apache/spark" is thanks to apache/spark Contributors
+ * at https://github.com/apache/spark, licensed under Apache v2.0 available at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+class TypedDataset[T] protected[frameless] (
+ val dataset: Dataset[T]
+ )(implicit
+ val encoder: TypedEncoder[T])
extends TypedDatasetForwarded[T] { self =>
private implicit val spark: SparkSession = dataset.sparkSession
- /** Aggregates on the entire Dataset without groups.
- *
- * apache/spark
- */
+ /**
+ * Aggregates on the entire Dataset without groups.
+ *
+ * apache/spark
+ */
def agg[A](ca: TypedAggregate[T, A]): TypedDataset[A] = {
implicit val ea = ca.uencoder
val tuple1: TypedDataset[Tuple1[A]] = aggMany(ca)
@@ -42,10 +64,8 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
TypedEncoder[A].catalystRepr match {
case StructType(_) =>
// if column is struct, we use all its fields
- val df = tuple1
- .dataset
- .selectExpr("_1.*")
- .as[A](TypedExpressionEncoder[A])
+ val df =
+ tuple1.dataset.selectExpr("_1.*").as[A](TypedExpressionEncoder[A])
TypedDataset.create(df)
case other =>
@@ -54,52 +74,59 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
}
}
- /** Aggregates on the entire Dataset without groups.
- *
- * apache/spark
- */
+ /**
+ * Aggregates on the entire Dataset without groups.
+ *
+ * apache/spark
+ */
def agg[A, B](
- ca: TypedAggregate[T, A],
- cb: TypedAggregate[T, B]
- ): TypedDataset[(A, B)] = {
+ ca: TypedAggregate[T, A],
+ cb: TypedAggregate[T, B]
+ ): TypedDataset[(A, B)] = {
implicit val (ea, eb) = (ca.uencoder, cb.uencoder)
aggMany(ca, cb)
}
- /** Aggregates on the entire Dataset without groups.
- *
- * apache/spark
- */
+ /**
+ * Aggregates on the entire Dataset without groups.
+ *
+ * apache/spark
+ */
def agg[A, B, C](
- ca: TypedAggregate[T, A],
- cb: TypedAggregate[T, B],
- cc: TypedAggregate[T, C]
- ): TypedDataset[(A, B, C)] = {
+ ca: TypedAggregate[T, A],
+ cb: TypedAggregate[T, B],
+ cc: TypedAggregate[T, C]
+ ): TypedDataset[(A, B, C)] = {
implicit val (ea, eb, ec) = (ca.uencoder, cb.uencoder, cc.uencoder)
aggMany(ca, cb, cc)
}
- /** Aggregates on the entire Dataset without groups.
- *
- * apache/spark
- */
+ /**
+ * Aggregates on the entire Dataset without groups.
+ *
+ * apache/spark
+ */
def agg[A, B, C, D](
- ca: TypedAggregate[T, A],
- cb: TypedAggregate[T, B],
- cc: TypedAggregate[T, C],
- cd: TypedAggregate[T, D]
- ): TypedDataset[(A, B, C, D)] = {
- implicit val (ea, eb, ec, ed) = (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder)
+ ca: TypedAggregate[T, A],
+ cb: TypedAggregate[T, B],
+ cc: TypedAggregate[T, C],
+ cd: TypedAggregate[T, D]
+ ): TypedDataset[(A, B, C, D)] = {
+ implicit val (ea, eb, ec, ed) =
+ (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder)
aggMany(ca, cb, cc, cd)
}
- /** Aggregates on the entire Dataset without groups.
- *
- * apache/spark
- */
+ /**
+ * Aggregates on the entire Dataset without groups.
+ *
+ * apache/spark
+ */
object aggMany extends ProductArgs {
- def applyProduct[U <: HList, Out0 <: HList, Out](columns: U)
- (implicit
+
+ def applyProduct[U <: HList, Out0 <: HList, Out](
+ columns: U
+ )(implicit
i0: AggregateTypes.Aux[T, U, Out0],
i1: ToTraversable.Aux[U, List, UntypedExpression[T]],
i2: Tupler.Aux[Out0, Out],
@@ -109,7 +136,7 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
val underlyingColumns = columns.toList[UntypedExpression[T]]
val cols: Seq[Column] = for {
(c, i) <- columns.toList[UntypedExpression[T]].zipWithIndex
- } yield new Column(c.expr).as(s"_${i+1}")
+ } yield new Column(c.expr).as(s"_${i + 1}")
// Workaround to SPARK-20346. One alternative is to allow the result to be Vector(null) for empty DataFrames.
// Another one would be to return an Option.
@@ -117,129 +144,163 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
for {
(c, i) <- underlyingColumns.zipWithIndex
if !c.uencoder.nullable
- } yield s"_${i+1} is not null"
- ).mkString(" or ")
+ } yield s"_${i + 1} is not null"
+ ).mkString(" or ")
- val selected = dataset.toDF().agg(cols.head, cols.tail:_*).as[Out](TypedExpressionEncoder[Out])
- TypedDataset.create[Out](if (filterStr.isEmpty) selected else selected.filter(filterStr))
+ val selected = dataset
+ .toDF()
+ .agg(cols.head, cols.tail: _*)
+ .as[Out](TypedExpressionEncoder[Out])
+ TypedDataset.create[Out](
+ if (filterStr.isEmpty) selected else selected.filter(filterStr)
+ )
}
}
/** Returns a new [[TypedDataset]] where each record has been mapped on to the specified type. */
- def as[U]()(implicit as: As[T, U]): TypedDataset[U] = {
+ def as[U](
+ )(implicit
+ as: As[T, U]
+ ): TypedDataset[U] = {
implicit val uencoder = as.encoder
TypedDataset.create(dataset.as[U](TypedExpressionEncoder[U]))
}
- /** Returns a checkpointed version of this [[TypedDataset]]. Checkpointing can be used to truncate the
- * logical plan of this Dataset, which is especially useful in iterative algorithms where the
- * plan may grow exponentially. It will be saved to files inside the checkpoint
- * directory set with `SparkContext#setCheckpointDir`.
- *
- * Differs from `Dataset#checkpoint` by wrapping its result into an effect-suspending `F[_]`.
- *
- * apache/spark
- */
- def checkpoint[F[_]](eager: Boolean)(implicit F: SparkDelay[F]): F[TypedDataset[T]] =
+ /**
+ * Returns a checkpointed version of this [[TypedDataset]]. Checkpointing can be used to truncate the
+ * logical plan of this Dataset, which is especially useful in iterative algorithms where the
+ * plan may grow exponentially. It will be saved to files inside the checkpoint
+ * directory set with `SparkContext#setCheckpointDir`.
+ *
+ * Differs from `Dataset#checkpoint` by wrapping its result into an effect-suspending `F[_]`.
+ *
+ * apache/spark
+ */
+ def checkpoint[F[_]](
+ eager: Boolean
+ )(implicit
+ F: SparkDelay[F]
+ ): F[TypedDataset[T]] =
F.delay(TypedDataset.create[T](dataset.checkpoint(eager)))
- /** Returns a new [[TypedDataset]] where each record has been mapped on to the specified type.
- * Unlike `as` the projection U may include a subset of the columns of T and the column names and types must agree.
- *
- * {{{
- * case class Foo(i: Int, j: String)
- * case class Bar(j: String)
- *
- * val t: TypedDataset[Foo] = ...
- * val b: TypedDataset[Bar] = t.project[Bar]
- *
- * case class BarErr(e: String)
- * // The following does not compile because `Foo` doesn't have a field with name `e`
- * val e: TypedDataset[BarErr] = t.project[BarErr]
- * }}}
- */
- def project[U](implicit projector: SmartProject[T,U]): TypedDataset[U] = projector.apply(this)
-
- /** Returns a new [[TypedDataset]] that contains the elements of both this and the `other` [[TypedDataset]]
- * combined.
- *
- * Note that, this function is not a typical set union operation, in that it does not eliminate
- * duplicate items. As such, it is analogous to `UNION ALL` in SQL.
- *
- * Differs from `Dataset#union` by aligning fields if possible.
- * It will not compile if `Datasets` have not compatible schema.
- *
- * Example:
- * {{{
- * case class Foo(x: Int, y: Long)
- * case class Bar(y: Long, x: Int)
- * case class Faz(x: Int, y: Int, z: Int)
- *
- * foo: TypedDataset[Foo] = ...
- * bar: TypedDataset[Bar] = ...
- * faz: TypedDataset[Faz] = ...
- *
- * foo union bar: TypedDataset[Foo]
- * foo union faz: TypedDataset[Foo]
- * // won't compile, you need to reverse order, you can't project from less fields to more
- * faz union foo
- *
- * }}}
- *
- * apache/spark
- */
- def union[U: TypedEncoder](other: TypedDataset[U])(implicit projector: SmartProject[U, T]): TypedDataset[T] =
+ /**
+ * Returns a new [[TypedDataset]] where each record has been mapped on to the specified type.
+   * Unlike `as`, the projection `U` may include a subset of the columns of `T`; the column names and types must agree.
+ *
+ * {{{
+ * case class Foo(i: Int, j: String)
+ * case class Bar(j: String)
+ *
+ * val t: TypedDataset[Foo] = ...
+ * val b: TypedDataset[Bar] = t.project[Bar]
+ *
+ * case class BarErr(e: String)
+ * // The following does not compile because `Foo` doesn't have a field with name `e`
+ * val e: TypedDataset[BarErr] = t.project[BarErr]
+ * }}}
+ */
+ def project[U](
+ implicit
+ projector: SmartProject[T, U]
+ ): TypedDataset[U] = projector.apply(this)
+
+ /**
+ * Returns a new [[TypedDataset]] that contains the elements of both this and the `other` [[TypedDataset]]
+ * combined.
+ *
+   * Note that this function is not a typical set union operation, in that it does not eliminate
+ * duplicate items. As such, it is analogous to `UNION ALL` in SQL.
+ *
+ * Differs from `Dataset#union` by aligning fields if possible.
+   * It will not compile if the `Datasets` do not have compatible schemas.
+ *
+ * Example:
+ * {{{
+ * case class Foo(x: Int, y: Long)
+ * case class Bar(y: Long, x: Int)
+ * case class Faz(x: Int, y: Int, z: Int)
+ *
+ * foo: TypedDataset[Foo] = ...
+ * bar: TypedDataset[Bar] = ...
+ * faz: TypedDataset[Faz] = ...
+ *
+ * foo union bar: TypedDataset[Foo]
+ * foo union faz: TypedDataset[Foo]
+   * // won't compile: you need to reverse the order; you can't project from fewer fields to more
+ * faz union foo
+ *
+ * }}}
+ *
+ * apache/spark
+ */
+ def union[U: TypedEncoder](
+ other: TypedDataset[U]
+ )(implicit
+ projector: SmartProject[U, T]
+ ): TypedDataset[T] =
TypedDataset.create(dataset.union(other.project[T].dataset))
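
The field alignment described above, written out as a compilable sketch (hypothetical case classes; an implicit `SparkSession` is assumed to be in scope):

    case class Foo(x: Int, y: Long)
    case class Bar(y: Long, x: Int)

    val foo: TypedDataset[Foo] = TypedDataset.create(Seq(Foo(1, 2L)))
    val bar: TypedDataset[Bar] = TypedDataset.create(Seq(Bar(3L, 4)))

    // Bar is reshaped into Foo via SmartProject before Dataset#union runs,
    // so fields are matched by name rather than by position
    val all: TypedDataset[Foo] = foo.union(bar)
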
- /** Returns a new [[TypedDataset]] that contains the elements of both this and the `other` [[TypedDataset]]
- * combined.
- *
- * Note that, this function is not a typical set union operation, in that it does not eliminate
- * duplicate items. As such, it is analogous to `UNION ALL` in SQL.
- *
- * apache/spark
- */
+ /**
+ * Returns a new [[TypedDataset]] that contains the elements of both this and the `other` [[TypedDataset]]
+ * combined.
+ *
+   * Note that this function is not a typical set union operation, in that it does not eliminate
+ * duplicate items. As such, it is analogous to `UNION ALL` in SQL.
+ *
+ * apache/spark
+ */
def union(other: TypedDataset[T]): TypedDataset[T] = {
TypedDataset.create(dataset.union(other.dataset))
}
- /** Returns the number of elements in the [[TypedDataset]].
- *
- * Differs from `Dataset#count` by wrapping its result into an effect-suspending `F[_]`.
- */
- def count[F[_]]()(implicit F: SparkDelay[F]): F[Long] =
+ /**
+ * Returns the number of elements in the [[TypedDataset]].
+ *
+ * Differs from `Dataset#count` by wrapping its result into an effect-suspending `F[_]`.
+ */
+ def count[F[_]](
+ )(implicit
+ F: SparkDelay[F]
+ ): F[Long] =
F.delay(dataset.count())
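
A minimal sketch of how these effect-suspended actions compose, assuming cats-effect `IO` and the `SparkDelay[IO]` instance provided by the frameless-cats module:

    import cats.effect.IO
    import frameless.cats.implicits._ // SparkDelay[IO] via Sync[IO]

    def report[A](ds: TypedDataset[A]): IO[Unit] =
      for {
        total  <- ds.count[IO]()                   // F[Long]
        sample <- ds.take[IO](5)                   // F[Seq[A]]
        _      <- ds.show[IO](5, truncate = false) // F[Unit]
        _      <- IO(println(s"total=$total, sample=$sample"))
      } yield ()
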
- /** Returns `TypedColumn` of type `A` given its name (alias for `col`).
- *
- * {{{
- * tf('id)
- * }}}
- *
- * It is statically checked that column with such name exists and has type `A`.
- */
- def apply[A](column: Witness.Lt[Symbol])
- (implicit
+ /**
+ * Returns `TypedColumn` of type `A` given its name (alias for `col`).
+ *
+ * {{{
+ * tf('id)
+ * }}}
+ *
+   * It is statically checked that a column with that name exists and has type `A`.
+ */
+ def apply[A](
+ column: Witness.Lt[Symbol]
+ )(implicit
i0: TypedColumn.Exists[T, column.T, A],
i1: TypedEncoder[A]
): TypedColumn[T, A] = col(column)
- /** Returns `TypedColumn` of type `A` given its name.
- *
- * {{{
- * tf.col('id)
- * }}}
- *
- * It is statically checked that column with such name exists and has type `A`.
- */
- def col[A](column: Witness.Lt[Symbol])
- (implicit
+ /**
+ * Returns `TypedColumn` of type `A` given its name.
+ *
+ * {{{
+ * tf.col('id)
+ * }}}
+ *
+   * It is statically checked that a column with that name exists and has type `A`.
+ */
+ def col[A](
+ column: Witness.Lt[Symbol]
+ )(implicit
i0: TypedColumn.Exists[T, column.T, A],
i1: TypedEncoder[A]
): TypedColumn[T, A] =
- new TypedColumn[T, A](dataset(column.value.name).as[A](TypedExpressionEncoder[A]))
+ new TypedColumn[T, A](
+ dataset(column.value.name).as[A](TypedExpressionEncoder[A])
+ )
- /** Returns `TypedColumn` of type `A` given a lambda indicating the field.
+ /**
+ * Returns `TypedColumn` of type `A` given a lambda indicating the field.
*
* {{{
* td.col(_.id)
@@ -250,12 +311,13 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
def col[A](x: Function1[T, A]): TypedColumn[T, A] =
macro TypedColumnMacroImpl.applyImpl[T, A]
- /** Projects the entire `TypedDataset[T]` into a single column of type `TypedColumn[T,T]`.
- * {{{
- * ts: TypedDataset[Foo] = ...
- * ts.select(ts.asCol, ts.asCol): TypedDataset[(Foo,Foo)]
- * }}}
- */
+ /**
+ * Projects the entire `TypedDataset[T]` into a single column of type `TypedColumn[T,T]`.
+ * {{{
+ * ts: TypedDataset[Foo] = ...
+ * ts.select(ts.asCol, ts.asCol): TypedDataset[(Foo,Foo)]
+ * }}}
+ */
def asCol: TypedColumn[T, T] = {
val projectedColumn: Column = encoder.catalystRepr match {
case StructType(_) =>
@@ -265,78 +327,98 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
case _ =>
dataset.col(dataset.columns.head)
}
-
- new TypedColumn[T,T](projectedColumn)
+
+ new TypedColumn[T, T](projectedColumn)
}
- /** References the entire `TypedDataset[T]` as a single column
- * of type `TypedColumn[T,T]` so it can be used in a join operation.
- *
- * {{{
- * def nameJoin(ds1: TypedDataset[Person], ds2: TypedDataset[Name]) =
- * ds1.joinLeftSemi(ds2)(ds1.col('name) === ds2.asJoinColValue)
- * }}}
- */
- def asJoinColValue(implicit i0: IsValueClass[T]): TypedColumn[T, T] = {
+ /**
+ * References the entire `TypedDataset[T]` as a single column
+ * of type `TypedColumn[T,T]` so it can be used in a join operation.
+ *
+ * {{{
+ * def nameJoin(ds1: TypedDataset[Person], ds2: TypedDataset[Name]) =
+ * ds1.joinLeftSemi(ds2)(ds1.col('name) === ds2.asJoinColValue)
+ * }}}
+ */
+ def asJoinColValue(
+ implicit
+ i0: IsValueClass[T]
+ ): TypedColumn[T, T] = {
import _root_.frameless.syntax._
dataset.col("value").typedColumn
}
object colMany extends SingletonProductArgs {
- def applyProduct[U <: HList, Out](columns: U)
- (implicit
+
+ def applyProduct[U <: HList, Out](
+ columns: U
+ )(implicit
i0: TypedColumn.ExistsMany[T, U, Out],
i1: TypedEncoder[Out],
i2: ToTraversable.Aux[U, List, Symbol]
): TypedColumn[T, Out] = {
- val names = columns.toList[Symbol].map(_.name)
- val colExpr = FramelessInternals.resolveExpr(dataset, names)
- new TypedColumn[T, Out](colExpr)
- }
+ val names = columns.toList[Symbol].map(_.name)
+ val colExpr = FramelessInternals.resolveExpr(dataset, names)
+ new TypedColumn[T, Out](colExpr)
+ }
}
- /** Right hand side disambiguation of `col` for join expressions.
- * To be used when writting self-joins, noop in other circumstances.
- *
- * Note: In vanilla Spark, disambiguation in self-joins is acheaved using
- * String based aliases, which is obviously unsafe.
- */
- def colRight[A](column: Witness.Lt[Symbol])
- (implicit
+ /**
+ * Right hand side disambiguation of `col` for join expressions.
+   * To be used when writing self-joins; a noop in other circumstances.
+   *
+   * Note: In vanilla Spark, disambiguation in self-joins is achieved using
+   * String-based aliases, which is obviously unsafe.
+ */
+ def colRight[A](
+ column: Witness.Lt[Symbol]
+ )(implicit
i0: TypedColumn.Exists[T, column.T, A],
i1: TypedEncoder[A]
): TypedColumn[T, A] =
- new TypedColumn[T, A](FramelessInternals.DisambiguateRight(col(column).expr))
-
- /** Left hand side disambiguation of `col` for join expressions.
- * To be used when writting self-joins, noop in other circumstances.
- *
- * Note: In vanilla Spark, disambiguation in self-joins is acheaved using
- * String based aliases, which is obviously unsafe.
- */
- def colLeft[A](column: Witness.Lt[Symbol])
- (implicit
+ new TypedColumn[T, A](
+ FramelessInternals.DisambiguateRight(col(column).expr)
+ )
+
+ /**
+ * Left hand side disambiguation of `col` for join expressions.
+   * To be used when writing self-joins; a noop in other circumstances.
+   *
+   * Note: In vanilla Spark, disambiguation in self-joins is achieved using
+   * String-based aliases, which is obviously unsafe.
+ */
+ def colLeft[A](
+ column: Witness.Lt[Symbol]
+ )(implicit
i0: TypedColumn.Exists[T, column.T, A],
i1: TypedEncoder[A]
): TypedColumn[T, A] =
- new TypedColumn[T, A](FramelessInternals.DisambiguateLeft(col(column).expr))
-
- /** Returns a `Seq` that contains all the elements in this [[TypedDataset]].
- *
- * Running this operation requires moving all the data into the application's driver process, and
- * doing so on a very large [[TypedDataset]] can crash the driver process with OutOfMemoryError.
- *
- * Differs from `Dataset#collect` by wrapping its result into an effect-suspending `F[_]`.
- */
- def collect[F[_]]()(implicit F: SparkDelay[F]): F[Seq[T]] =
+ new TypedColumn[T, A](FramelessInternals.DisambiguateLeft(col(column).expr))
+
+ /**
+ * Returns a `Seq` that contains all the elements in this [[TypedDataset]].
+ *
+ * Running this operation requires moving all the data into the application's driver process, and
+ * doing so on a very large [[TypedDataset]] can crash the driver process with OutOfMemoryError.
+ *
+ * Differs from `Dataset#collect` by wrapping its result into an effect-suspending `F[_]`.
+ */
+ def collect[F[_]](
+ )(implicit
+ F: SparkDelay[F]
+ ): F[Seq[T]] =
F.delay(dataset.collect().toSeq)
- /** Optionally returns the first element in this [[TypedDataset]].
- *
- * Differs from `Dataset#first` by wrapping its result into an `Option` and an effect-suspending `F[_]`.
- */
- def firstOption[F[_]]()(implicit F: SparkDelay[F]): F[Option[T]] =
+ /**
+ * Optionally returns the first element in this [[TypedDataset]].
+ *
+ * Differs from `Dataset#first` by wrapping its result into an `Option` and an effect-suspending `F[_]`.
+ */
+ def firstOption[F[_]](
+ )(implicit
+ F: SparkDelay[F]
+ ): F[Option[T]] =
F.delay {
try {
Option(dataset.first())
@@ -345,354 +427,462 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
}
}
- /** Returns the first `num` elements of this [[TypedDataset]] as a `Seq`.
- *
- * Running take requires moving data into the application's driver process, and doing so with
- * a very large `num` can crash the driver process with OutOfMemoryError.
- *
- * Differs from `Dataset#take` by wrapping its result into an effect-suspending `F[_]`.
- *
- * apache/spark
- */
- def take[F[_]](num: Int)(implicit F: SparkDelay[F]): F[Seq[T]] =
+ /**
+ * Returns the first `num` elements of this [[TypedDataset]] as a `Seq`.
+ *
+ * Running take requires moving data into the application's driver process, and doing so with
+ * a very large `num` can crash the driver process with OutOfMemoryError.
+ *
+ * Differs from `Dataset#take` by wrapping its result into an effect-suspending `F[_]`.
+ *
+ * apache/spark
+ */
+ def take[F[_]](
+ num: Int
+ )(implicit
+ F: SparkDelay[F]
+ ): F[Seq[T]] =
F.delay(dataset.take(num).toSeq)
- /** Return an iterator that contains all rows in this [[TypedDataset]].
- *
- * The iterator will consume as much memory as the largest partition in this [[TypedDataset]].
- *
- * NOTE: this results in multiple Spark jobs, and if the input [[TypedDataset]] is the result
- * of a wide transformation (e.g. join with different partitioners), to avoid
- * recomputing the input [[TypedDataset]] should be cached first.
- *
- * Differs from `Dataset#toLocalIterator()` by wrapping its result into an effect-suspending `F[_]`.
- *
- * apache/spark
- */
- def toLocalIterator[F[_]]()(implicit F: SparkDelay[F]): F[util.Iterator[T]] =
+ /**
+ * Return an iterator that contains all rows in this [[TypedDataset]].
+ *
+ * The iterator will consume as much memory as the largest partition in this [[TypedDataset]].
+ *
+ * NOTE: this results in multiple Spark jobs, and if the input [[TypedDataset]] is the result
+   * of a wide transformation (e.g. a join with different partitioners), the input
+   * [[TypedDataset]] should be cached first to avoid recomputing it.
+ *
+ * Differs from `Dataset#toLocalIterator()` by wrapping its result into an effect-suspending `F[_]`.
+ *
+ * apache/spark
+ */
+ def toLocalIterator[F[_]](
+ )(implicit
+ F: SparkDelay[F]
+ ): F[util.Iterator[T]] =
F.delay(dataset.toLocalIterator())
- /** Alias for firstOption().
- */
- def headOption[F[_]]()(implicit F: SparkDelay[F]): F[Option[T]] = firstOption()
+ /**
+ * Alias for firstOption().
+ */
+ def headOption[F[_]](
+ )(implicit
+ F: SparkDelay[F]
+ ): F[Option[T]] = firstOption()
- /** Alias for take().
- */
- def head[F[_]](num: Int)(implicit F: SparkDelay[F]): F[Seq[T]] = take(num)
+ /**
+ * Alias for take().
+ */
+ def head[F[_]](
+ num: Int
+ )(implicit
+ F: SparkDelay[F]
+ ): F[Seq[T]] = take(num)
// $COVERAGE-OFF$
- /** Alias for firstOption().
- */
- @deprecated("Method may throw exception. Use headOption or firstOption instead.", "0.5.0")
+ /**
+ * Alias for firstOption().
+ */
+ @deprecated(
+ "Method may throw exception. Use headOption or firstOption instead.",
+ "0.5.0"
+ )
def head: T = dataset.head()
- /** Alias for firstOption().
- */
- @deprecated("Method may throw exception. Use headOption or firstOption instead.", "0.5.0")
+ /**
+ * Alias for firstOption().
+ */
+ @deprecated(
+ "Method may throw exception. Use headOption or firstOption instead.",
+ "0.5.0"
+ )
def first: T = dataset.head()
// $COVERAGE-ONN$
- /** Displays the content of this [[TypedDataset]] in a tabular form. Strings more than 20 characters
- * will be truncated, and all cells will be aligned right. For example:
- * {{{
- * year month AVG('Adj Close) MAX('Adj Close)
- * 1980 12 0.503218 0.595103
- * 1981 01 0.523289 0.570307
- * 1982 02 0.436504 0.475256
- * 1983 03 0.410516 0.442194
- * 1984 04 0.450090 0.483521
- * }}}
- * @param numRows Number of rows to show
- * @param truncate Whether truncate long strings. If true, strings more than 20 characters will
- * be truncated and all cells will be aligned right
- *
- * Differs from `Dataset#show` by wrapping its result into an effect-suspending `F[_]`.
- *
- * apache/spark
- */
- def show[F[_]](numRows: Int = 20, truncate: Boolean = true)(implicit F: SparkDelay[F]): F[Unit] =
+ /**
+ * Displays the content of this [[TypedDataset]] in a tabular form. Strings more than 20 characters
+ * will be truncated, and all cells will be aligned right. For example:
+ * {{{
+ * year month AVG('Adj Close) MAX('Adj Close)
+ * 1980 12 0.503218 0.595103
+ * 1981 01 0.523289 0.570307
+ * 1982 02 0.436504 0.475256
+ * 1983 03 0.410516 0.442194
+ * 1984 04 0.450090 0.483521
+ * }}}
+ * @param numRows Number of rows to show
+   * @param truncate Whether to truncate long strings. If true, strings more than 20 characters will
+ * be truncated and all cells will be aligned right
+ *
+ * Differs from `Dataset#show` by wrapping its result into an effect-suspending `F[_]`.
+ *
+ * apache/spark
+ */
+ def show[F[_]](
+ numRows: Int = 20,
+ truncate: Boolean = true
+ )(implicit
+ F: SparkDelay[F]
+ ): F[Unit] =
F.delay(dataset.show(numRows, truncate))
- /** Returns a new [[frameless.TypedDataset]] that only contains elements where `column` is `true`.
- *
- * Differs from `TypedDatasetForward#filter` by taking a `TypedColumn[T, Boolean]` instead of a
- * `T => Boolean`. Using a column expression instead of a regular function save one Spark → Scala
- * deserialization which leads to better performance.
- */
+ /**
+ * Returns a new [[frameless.TypedDataset]] that only contains elements where `column` is `true`.
+ *
+   * Differs from `TypedDatasetForwarded#filter` by taking a `TypedColumn[T, Boolean]` instead of a
+   * `T => Boolean`. Using a column expression instead of a regular function saves one Spark → Scala
+   * deserialization, which leads to better performance.
+ */
def filter(column: TypedColumn[T, Boolean]): TypedDataset[T] = {
- val filtered = dataset.toDF()
- .filter(column.untyped)
- .as[T](TypedExpressionEncoder[T])
+ val filtered =
+ dataset.toDF().filter(column.untyped).as[T](TypedExpressionEncoder[T])
TypedDataset.create[T](filtered)
}
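
A short sketch contrasting the column-based `filter` above with the function-based one forwarded from `Dataset`, under a hypothetical `people: TypedDataset[Person]` with an `age: Int` field:

    // stays in Catalyst: no Spark -> Scala deserialization per row
    val adults = people.filter(people('age) >= 18)

    // forwarded Dataset#filter: runs the closure on deserialized Person values
    val adultsViaClosure = people.filter((p: Person) => p.age >= 18)
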
- /** Runs `func` on each element of this [[TypedDataset]].
- *
- * Differs from `Dataset#foreach` by wrapping its result into an effect-suspending `F[_]`.
- */
- def foreach[F[_]](func: T => Unit)(implicit F: SparkDelay[F]): F[Unit] =
+ /**
+ * Runs `func` on each element of this [[TypedDataset]].
+ *
+ * Differs from `Dataset#foreach` by wrapping its result into an effect-suspending `F[_]`.
+ */
+ def foreach[F[_]](
+ func: T => Unit
+ )(implicit
+ F: SparkDelay[F]
+ ): F[Unit] =
F.delay(dataset.foreach(func))
- /** Runs `func` on each partition of this [[TypedDataset]].
- *
- * Differs from `Dataset#foreachPartition` by wrapping its result into an effect-suspending `F[_]`.
- */
- def foreachPartition[F[_]](func: Iterator[T] => Unit)(implicit F: SparkDelay[F]): F[Unit] =
+ /**
+ * Runs `func` on each partition of this [[TypedDataset]].
+ *
+ * Differs from `Dataset#foreachPartition` by wrapping its result into an effect-suspending `F[_]`.
+ */
+ def foreachPartition[F[_]](
+ func: Iterator[T] => Unit
+ )(implicit
+ F: SparkDelay[F]
+ ): F[Unit] =
F.delay(dataset.foreachPartition(func))
/**
- * Create a multi-dimensional cube for the current [[TypedDataset]] using the specified column,
- * so we can run aggregation on it.
- * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions.
- *
- * Differs from `Dataset#cube` by wrapping values into `Option` instead of returning `null`.
- *
- * apache/spark
- */
+ * Create a multi-dimensional cube for the current [[TypedDataset]] using the specified column,
+ * so we can run aggregation on it.
+ * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions.
+ *
+ * Differs from `Dataset#cube` by wrapping values into `Option` instead of returning `null`.
+ *
+ * apache/spark
+ */
def cube[K1](
- c1: TypedColumn[T, K1]
- ): Cube1Ops[K1, T] = new Cube1Ops[K1, T](this, c1)
-
- /**
- * Create a multi-dimensional cube for the current [[TypedDataset]] using the specified columns,
- * so we can run aggregation on them.
- * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions.
- *
- * Differs from `Dataset#cube` by wrapping values into `Option` instead of returning `null`.
- *
- * apache/spark
- */
+ c1: TypedColumn[T, K1]
+ ): Cube1Ops[K1, T] = new Cube1Ops[K1, T](this, c1)
+
+ /**
+ * Create a multi-dimensional cube for the current [[TypedDataset]] using the specified columns,
+ * so we can run aggregation on them.
+ * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions.
+ *
+ * Differs from `Dataset#cube` by wrapping values into `Option` instead of returning `null`.
+ *
+ * apache/spark
+ */
def cube[K1, K2](
- c1: TypedColumn[T, K1],
- c2: TypedColumn[T, K2]
- ): Cube2Ops[K1, K2, T] = new Cube2Ops[K1, K2, T](this, c1, c2)
-
- /**
- * Create a multi-dimensional cube for the current [[TypedDataset]] using the specified columns,
- * so we can run aggregation on them.
- * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions.
- *
- * {{{
- * case class MyClass(a: Int, b: Int, c: Int)
- * val ds: TypedDataset[MyClass]
-
- * val cubeDataset: TypedDataset[(Option[A], Option[B], Long)] =
- * ds.cubeMany(ds('a), ds('b)).agg(count[MyClass]())
- *
- * // original dataset:
- * a b c
- * 10 20 1
- * 15 25 2
- *
- * // after aggregation:
- * _1 _2 _3
- * 15 null 1
- * 15 25 1
- * null null 2
- * null 25 1
- * null 20 1
- * 10 null 1
- * 10 20 1
- *
- * }}}
- *
- * Differs from `Dataset#cube` by wrapping values into `Option` instead of returning `null`.
- *
- * apache/spark
- */
+ c1: TypedColumn[T, K1],
+ c2: TypedColumn[T, K2]
+ ): Cube2Ops[K1, K2, T] = new Cube2Ops[K1, K2, T](this, c1, c2)
+
+ /**
+ * Create a multi-dimensional cube for the current [[TypedDataset]] using the specified columns,
+ * so we can run aggregation on them.
+ * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions.
+ *
+ * {{{
+ * case class MyClass(a: Int, b: Int, c: Int)
+ * val ds: TypedDataset[MyClass]
+ *
+   * val cubeDataset: TypedDataset[(Option[Int], Option[Int], Long)] =
+ * ds.cubeMany(ds('a), ds('b)).agg(count[MyClass]())
+ *
+ * // original dataset:
+ * a b c
+ * 10 20 1
+ * 15 25 2
+ *
+ * // after aggregation:
+ * _1 _2 _3
+ * 15 null 1
+ * 15 25 1
+ * null null 2
+ * null 25 1
+ * null 20 1
+ * 10 null 1
+ * 10 20 1
+ *
+ * }}}
+ *
+ * Differs from `Dataset#cube` by wrapping values into `Option` instead of returning `null`.
+ *
+ * apache/spark
+ */
object cubeMany extends ProductArgs {
- def applyProduct[TK <: HList, K <: HList, KT](groupedBy: TK)
- (implicit
+
+ def applyProduct[TK <: HList, K <: HList, KT](
+ groupedBy: TK
+ )(implicit
i0: ColumnTypes.Aux[T, TK, K],
i1: Tupler.Aux[K, KT],
i2: ToTraversable.Aux[TK, List, UntypedExpression[T]]
- ): CubeManyOps[T, TK, K, KT] = new CubeManyOps[T, TK, K, KT](self, groupedBy)
+ ): CubeManyOps[T, TK, K, KT] =
+ new CubeManyOps[T, TK, K, KT](self, groupedBy)
}
/**
- * Groups the [[TypedDataset]] using the specified columns, so that we can run aggregation on them.
- * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions.
- *
- * apache/spark
- */
+ * Groups the [[TypedDataset]] using the specified columns, so that we can run aggregation on them.
+ * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions.
+ *
+ * apache/spark
+ */
def groupBy[K1](
- c1: TypedColumn[T, K1]
- ): GroupedBy1Ops[K1, T] = new GroupedBy1Ops[K1, T](this, c1)
+ c1: TypedColumn[T, K1]
+ ): GroupedBy1Ops[K1, T] = new GroupedBy1Ops[K1, T](this, c1)
/**
- * Groups the [[TypedDataset]] using the specified columns, so that we can run aggregation on them.
- * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions.
- *
- * apache/spark
- */
+ * Groups the [[TypedDataset]] using the specified columns, so that we can run aggregation on them.
+ * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions.
+ *
+ * apache/spark
+ */
def groupBy[K1, K2](
- c1: TypedColumn[T, K1],
- c2: TypedColumn[T, K2]
- ): GroupedBy2Ops[K1, K2, T] = new GroupedBy2Ops[K1, K2, T](this, c1, c2)
-
- /**
- * Groups the [[TypedDataset]] using the specified columns, so that we can run aggregation on them.
- * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions.
- *
- * {{{
- * case class MyClass(a: Int, b: Int, c: Int)
- * val ds: TypedDataset[MyClass]
- *
- * val cubeDataset: TypedDataset[(Option[A], Option[B], Long)] =
- * ds.groupByMany(ds('a), ds('b)).agg(count[MyClass]())
- *
- * // original dataset:
- * a b c
- * 10 20 1
- * 15 25 2
- *
- * // after aggregation:
- * _1 _2 _3
- * 10 20 1
- * 15 25 1
- *
- * }}}
- *
- * apache/spark
- */
+ c1: TypedColumn[T, K1],
+ c2: TypedColumn[T, K2]
+ ): GroupedBy2Ops[K1, K2, T] = new GroupedBy2Ops[K1, K2, T](this, c1, c2)
+
+ /**
+ * Groups the [[TypedDataset]] using the specified columns, so that we can run aggregation on them.
+ * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions.
+ *
+ * {{{
+ * case class MyClass(a: Int, b: Int, c: Int)
+ * val ds: TypedDataset[MyClass]
+ *
+   * val groupedDataset: TypedDataset[(Int, Int, Long)] =
+ * ds.groupByMany(ds('a), ds('b)).agg(count[MyClass]())
+ *
+ * // original dataset:
+ * a b c
+ * 10 20 1
+ * 15 25 2
+ *
+ * // after aggregation:
+ * _1 _2 _3
+ * 10 20 1
+ * 15 25 1
+ *
+ * }}}
+ *
+ * apache/spark
+ */
object groupByMany extends ProductArgs {
- def applyProduct[TK <: HList, K <: HList, KT](groupedBy: TK)
- (implicit
+
+ def applyProduct[TK <: HList, K <: HList, KT](
+ groupedBy: TK
+ )(implicit
i0: ColumnTypes.Aux[T, TK, K],
i1: Tupler.Aux[K, KT],
i2: ToTraversable.Aux[TK, List, UntypedExpression[T]]
- ): GroupedByManyOps[T, TK, K, KT] = new GroupedByManyOps[T, TK, K, KT](self, groupedBy)
+ ): GroupedByManyOps[T, TK, K, KT] =
+ new GroupedByManyOps[T, TK, K, KT](self, groupedBy)
}
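
A compilable sketch of `groupBy` with aggregates, under a hypothetical `Purchase` schema (the aggregate functions come from `frameless.functions.aggregate`):

    import frameless.functions.aggregate._

    case class Purchase(customer: String, amount: Double)

    // assumes purchases: TypedDataset[Purchase] is in scope
    val totals: TypedDataset[(String, Double, Long)] =
      purchases
        .groupBy(purchases('customer))
        .agg(sum(purchases('amount)), count[Purchase]())
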
/**
- * Create a multi-dimensional rollup for the current [[TypedDataset]] using the specified column,
- * so we can run aggregation on it.
- * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions.
- *
- * Differs from `Dataset#rollup` by wrapping values into `Option` instead of returning `null`.
- *
- * apache/spark
- */
+ * Create a multi-dimensional rollup for the current [[TypedDataset]] using the specified column,
+ * so we can run aggregation on it.
+ * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions.
+ *
+ * Differs from `Dataset#rollup` by wrapping values into `Option` instead of returning `null`.
+ *
+ * apache/spark
+ */
def rollup[K1](
- c1: TypedColumn[T, K1]
- ): Rollup1Ops[K1, T] = new Rollup1Ops[K1, T](this, c1)
-
- /**
- * Create a multi-dimensional rollup for the current [[TypedDataset]] using the specified columns,
- * so we can run aggregation on them.
- * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions.
- *
- * Differs from `Dataset#rollup` by wrapping values into `Option` instead of returning `null`.
- *
- * apache/spark
- */
+ c1: TypedColumn[T, K1]
+ ): Rollup1Ops[K1, T] = new Rollup1Ops[K1, T](this, c1)
+
+ /**
+ * Create a multi-dimensional rollup for the current [[TypedDataset]] using the specified columns,
+ * so we can run aggregation on them.
+ * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions.
+ *
+ * Differs from `Dataset#rollup` by wrapping values into `Option` instead of returning `null`.
+ *
+ * apache/spark
+ */
def rollup[K1, K2](
- c1: TypedColumn[T, K1],
- c2: TypedColumn[T, K2]
- ): Rollup2Ops[K1, K2, T] = new Rollup2Ops[K1, K2, T](this, c1, c2)
-
- /**
- * Create a multi-dimensional rollup for the current [[TypedDataset]] using the specified columns,
- * so we can run aggregation on them.
- * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions.
- *
- * {{{
- * case class MyClass(a: Int, b: Int, c: Int)
- * val ds: TypedDataset[MyClass]
- *
- * val cubeDataset: TypedDataset[(Option[A], Option[B], Long)] =
- * ds.rollupMany(ds('a), ds('b)).agg(count[MyClass]())
- *
- * // original dataset:
- * a b c
- * 10 20 1
- * 15 25 2
- *
- * // after aggregation:
- * _1 _2 _3
- * 15 null 1
- * 15 25 1
- * null null 2
- * 10 null 1
- * 10 20 1
- *
- * }}}
- *
- * Differs from `Dataset#rollup` by wrapping values into `Option` instead of returning `null`.
- *
- * apache/spark
- */
+ c1: TypedColumn[T, K1],
+ c2: TypedColumn[T, K2]
+ ): Rollup2Ops[K1, K2, T] = new Rollup2Ops[K1, K2, T](this, c1, c2)
+
+ /**
+ * Create a multi-dimensional rollup for the current [[TypedDataset]] using the specified columns,
+ * so we can run aggregation on them.
+ * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions.
+ *
+ * {{{
+ * case class MyClass(a: Int, b: Int, c: Int)
+ * val ds: TypedDataset[MyClass]
+ *
+   * val rollupDataset: TypedDataset[(Option[Int], Option[Int], Long)] =
+ * ds.rollupMany(ds('a), ds('b)).agg(count[MyClass]())
+ *
+ * // original dataset:
+ * a b c
+ * 10 20 1
+ * 15 25 2
+ *
+ * // after aggregation:
+ * _1 _2 _3
+ * 15 null 1
+ * 15 25 1
+ * null null 2
+ * 10 null 1
+ * 10 20 1
+ *
+ * }}}
+ *
+ * Differs from `Dataset#rollup` by wrapping values into `Option` instead of returning `null`.
+ *
+ * apache/spark
+ */
object rollupMany extends ProductArgs {
- def applyProduct[TK <: HList, K <: HList, KT](groupedBy: TK)
- (implicit
+
+ def applyProduct[TK <: HList, K <: HList, KT](
+ groupedBy: TK
+ )(implicit
i0: ColumnTypes.Aux[T, TK, K],
i1: Tupler.Aux[K, KT],
i2: ToTraversable.Aux[TK, List, UntypedExpression[T]]
- ): RollupManyOps[T, TK, K, KT] = new RollupManyOps[T, TK, K, KT](self, groupedBy)
+ ): RollupManyOps[T, TK, K, KT] =
+ new RollupManyOps[T, TK, K, KT](self, groupedBy)
}
/** Computes the cartesian project of `this` `Dataset` with the `other` `Dataset` */
- def joinCross[U](other: TypedDataset[U])
- (implicit e: TypedEncoder[(T, U)]): TypedDataset[(T, U)] =
- new TypedDataset(self.dataset.joinWith(other.dataset, new Column(Literal(true)), "cross"))
-
- /** Computes the full outer join of `this` `Dataset` with the `other` `Dataset`,
- * returning a `Tuple2` for each pair where condition evaluates to true.
- */
- def joinFull[U](other: TypedDataset[U])(condition: TypedColumn[T with U, Boolean])
- (implicit e: TypedEncoder[(Option[T], Option[U])]): TypedDataset[(Option[T], Option[U])] =
- new TypedDataset(self.dataset.joinWith(other.dataset, condition.untyped, "full")
- .as[(Option[T], Option[U])](TypedExpressionEncoder[(Option[T], Option[U])]))
-
- /** Computes the inner join of `this` `Dataset` with the `other` `Dataset`,
- * returning a `Tuple2` for each pair where condition evaluates to true.
- */
- def joinInner[U](other: TypedDataset[U])(condition: TypedColumn[T with U, Boolean])
- (implicit e: TypedEncoder[(T, U)]): TypedDataset[(T, U)] = {
- import FramelessInternals._
-
- val leftPlan = logicalPlan(dataset)
- val rightPlan = logicalPlan(other.dataset)
- val join = disambiguate(Join(leftPlan, rightPlan, Inner, Some(condition.expr), JoinHint.NONE))
- val joinedPlan = joinPlan(dataset, join, leftPlan, rightPlan)
- val joinedDs = mkDataset(dataset.sqlContext, joinedPlan, TypedExpressionEncoder[(T, U)])
-
- TypedDataset.create[(T, U)](joinedDs)
- }
+ def joinCross[U](
+ other: TypedDataset[U]
+ )(implicit
+ e: TypedEncoder[(T, U)]
+ ): TypedDataset[(T, U)] =
+ new TypedDataset(
+ self.dataset.joinWith(other.dataset, new Column(Literal(true)), "cross")
+ )
- /** Computes the left outer join of `this` `Dataset` with the `other` `Dataset`,
- * returning a `Tuple2` for each pair where condition evaluates to true.
- */
- def joinLeft[U](other: TypedDataset[U])(condition: TypedColumn[T with U, Boolean])
- (implicit e: TypedEncoder[(T, Option[U])]): TypedDataset[(T, Option[U])] =
- new TypedDataset(self.dataset.joinWith(other.dataset, condition.untyped, "left_outer")
- .as[(T, Option[U])](TypedExpressionEncoder[(T, Option[U])]))
-
- /** Computes the left semi join of `this` `Dataset` with the `other` `Dataset`,
- * returning a `Tuple2` for each pair where condition evaluates to true.
- */
- def joinLeftSemi[U](other: TypedDataset[U])(condition: TypedColumn[T with U, Boolean]): TypedDataset[T] =
- new TypedDataset(self.dataset.join(other.dataset, condition.untyped, "leftsemi")
- .as[T](TypedExpressionEncoder(encoder)))
-
- /** Computes the left anti join of `this` `Dataset` with the `other` `Dataset`,
- * returning a `Tuple2` for each pair where condition evaluates to true.
- */
- def joinLeftAnti[U](other: TypedDataset[U])(condition: TypedColumn[T with U, Boolean]): TypedDataset[T] =
- new TypedDataset(self.dataset.join(other.dataset, condition.untyped, "leftanti")
- .as[T](TypedExpressionEncoder(encoder)))
-
- /** Computes the right outer join of `this` `Dataset` with the `other` `Dataset`,
- * returning a `Tuple2` for each pair where condition evaluates to true.
- */
- def joinRight[U](other: TypedDataset[U])(condition: TypedColumn[T with U, Boolean])
- (implicit e: TypedEncoder[(Option[T], U)]): TypedDataset[(Option[T], U)] =
- new TypedDataset(self.dataset.joinWith(other.dataset, condition.untyped, "right_outer")
- .as[(Option[T], U)](TypedExpressionEncoder[(Option[T], U)]))
+ /**
+ * Computes the full outer join of `this` `Dataset` with the `other` `Dataset`,
+   * returning a `Tuple2` for each pair where the condition evaluates to true.
+ */
+ def joinFull[U](
+ other: TypedDataset[U]
+ )(condition: TypedColumn[T with U, Boolean]
+ )(implicit
+ e: TypedEncoder[(Option[T], Option[U])]
+ ): TypedDataset[(Option[T], Option[U])] =
+ new TypedDataset(
+ self.dataset
+ .joinWith(other.dataset, condition.untyped, "full")
+ .as[(Option[T], Option[U])](
+ TypedExpressionEncoder[(Option[T], Option[U])]
+ )
+ )
+
+ /**
+ * Computes the inner join of `this` `Dataset` with the `other` `Dataset`,
+   * returning a `Tuple2` for each pair where the condition evaluates to true.
+ */
+ def joinInner[U](
+ other: TypedDataset[U]
+ )(condition: TypedColumn[T with U, Boolean]
+ )(implicit
+ e: TypedEncoder[(T, U)]
+ ): TypedDataset[(T, U)] = {
+ import FramelessInternals._
+
+ val leftPlan = logicalPlan(dataset)
+ val rightPlan = logicalPlan(other.dataset)
+ val join = disambiguate(
+ Join(leftPlan, rightPlan, Inner, Some(condition.expr), JoinHint.NONE)
+ )
+ val joinedPlan = joinPlan(dataset, join, leftPlan, rightPlan)
+ val joinedDs =
+ mkDataset(dataset.sqlContext, joinedPlan, TypedExpressionEncoder[(T, U)])
+
+ TypedDataset.create[(T, U)](joinedDs)
+ }
+
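
A sketch of an inner join under hypothetical `Person`/`Purchase` schemas; for self-joins, the `colLeft`/`colRight` variants above disambiguate which side a column refers to:

    case class Person(id: Long, name: String)
    case class Purchase(personId: Long, amount: Double)

    // assumes people: TypedDataset[Person] and purchases: TypedDataset[Purchase]
    val joined: TypedDataset[(Person, Purchase)] =
      people.joinInner(purchases)(people('id) === purchases('personId))
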
+ /**
+ * Computes the left outer join of `this` `Dataset` with the `other` `Dataset`,
+   * returning a `Tuple2` for each pair where the condition evaluates to true.
+ */
+ def joinLeft[U](
+ other: TypedDataset[U]
+ )(condition: TypedColumn[T with U, Boolean]
+ )(implicit
+ e: TypedEncoder[(T, Option[U])]
+ ): TypedDataset[(T, Option[U])] =
+ new TypedDataset(
+ self.dataset
+ .joinWith(other.dataset, condition.untyped, "left_outer")
+ .as[(T, Option[U])](TypedExpressionEncoder[(T, Option[U])])
+ )
+
+ /**
+ * Computes the left semi join of `this` `Dataset` with the `other` `Dataset`,
+   * keeping only the rows of `this` `Dataset` that match the condition.
+ */
+ def joinLeftSemi[U](
+ other: TypedDataset[U]
+ )(condition: TypedColumn[T with U, Boolean]
+ ): TypedDataset[T] =
+ new TypedDataset(
+ self.dataset
+ .join(other.dataset, condition.untyped, "leftsemi")
+ .as[T](TypedExpressionEncoder(encoder))
+ )
+
+ /**
+ * Computes the left anti join of `this` `Dataset` with the `other` `Dataset`,
+   * keeping only the rows of `this` `Dataset` that have no match in `other`.
+ */
+ def joinLeftAnti[U](
+ other: TypedDataset[U]
+ )(condition: TypedColumn[T with U, Boolean]
+ ): TypedDataset[T] =
+ new TypedDataset(
+ self.dataset
+ .join(other.dataset, condition.untyped, "leftanti")
+ .as[T](TypedExpressionEncoder(encoder))
+ )
+
+ /**
+ * Computes the right outer join of `this` `Dataset` with the `other` `Dataset`,
+   * returning a `Tuple2` for each pair where the condition evaluates to true.
+ */
+ def joinRight[U](
+ other: TypedDataset[U]
+ )(condition: TypedColumn[T with U, Boolean]
+ )(implicit
+ e: TypedEncoder[(Option[T], U)]
+ ): TypedDataset[(Option[T], U)] =
+ new TypedDataset(
+ self.dataset
+ .joinWith(other.dataset, condition.untyped, "right_outer")
+ .as[(Option[T], U)](TypedExpressionEncoder[(Option[T], U)])
+ )
private def disambiguate(join: Join): Join = {
- val plan = FramelessInternals.ofRows(dataset.sparkSession, join).queryExecution.analyzed.asInstanceOf[Join]
+ val plan = FramelessInternals
+ .ofRows(dataset.sparkSession, join)
+ .queryExecution
+ .analyzed
+ .asInstanceOf[Join]
val disambiguated = plan.condition.map(_.transform {
case FramelessInternals.DisambiguateLeft(tagged: AttributeReference) =>
val leftDs = FramelessInternals.ofRows(spark, plan.left)
@@ -707,43 +897,82 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
plan.copy(condition = disambiguated)
}
- /** Takes a function from A => R and converts it to a UDF for TypedColumn[T, A] => TypedColumn[T, R].
- */
- def makeUDF[A: TypedEncoder, R: TypedEncoder](f: A => R):
- TypedColumn[T, A] => TypedColumn[T, R] = functions.udf(f)
-
- /** Takes a function from (A1, A2) => R and converts it to a UDF for
- * (TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R].
- */
- def makeUDF[A1: TypedEncoder, A2: TypedEncoder, R: TypedEncoder](f: (A1, A2) => R):
- (TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R] = functions.udf(f)
-
- /** Takes a function from (A1, A2, A3) => R and converts it to a UDF for
- * (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R].
- */
- def makeUDF[A1: TypedEncoder, A2: TypedEncoder, A3: TypedEncoder, R: TypedEncoder](f: (A1, A2, A3) => R):
- (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R] = functions.udf(f)
-
- /** Takes a function from (A1, A2, A3, A4) => R and converts it to a UDF for
- * (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R].
- */
- def makeUDF[A1: TypedEncoder, A2: TypedEncoder, A3: TypedEncoder, A4: TypedEncoder, R: TypedEncoder](f: (A1, A2, A3, A4) => R):
- (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R] = functions.udf(f)
-
- /** Takes a function from (A1, A2, A3, A4, A5) => R and converts it to a UDF for
- * (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R].
- */
- def makeUDF[A1: TypedEncoder, A2: TypedEncoder, A3: TypedEncoder, A4: TypedEncoder, A5: TypedEncoder, R: TypedEncoder](f: (A1, A2, A3, A4, A5) => R):
- (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R] = functions.udf(f)
-
- /** Type-safe projection from type T to Tuple1[A]
- * {{{
- * d.select( d('a), d('a)+d('b), ... )
- * }}}
- */
+ /**
+ * Takes a function from A => R and converts it to a UDF for TypedColumn[T, A] => TypedColumn[T, R].
+ */
+ def makeUDF[A: TypedEncoder, R: TypedEncoder](
+ f: A => R
+ ): TypedColumn[T, A] => TypedColumn[T, R] = functions.udf(f)
+
+ /**
+ * Takes a function from (A1, A2) => R and converts it to a UDF for
+ * (TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R].
+ */
+ def makeUDF[A1: TypedEncoder, A2: TypedEncoder, R: TypedEncoder](
+ f: (A1, A2) => R
+ ): (TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R] =
+ functions.udf(f)
+
+ /**
+ * Takes a function from (A1, A2, A3) => R and converts it to a UDF for
+ * (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R].
+ */
+ def makeUDF[
+ A1: TypedEncoder,
+ A2: TypedEncoder,
+ A3: TypedEncoder,
+ R: TypedEncoder
+ ](f: (A1, A2, A3) => R
+ ): (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R] =
+ functions.udf(f)
+
+ /**
+ * Takes a function from (A1, A2, A3, A4) => R and converts it to a UDF for
+ * (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R].
+ */
+ def makeUDF[
+ A1: TypedEncoder,
+ A2: TypedEncoder,
+ A3: TypedEncoder,
+ A4: TypedEncoder,
+ R: TypedEncoder
+ ](f: (A1, A2, A3, A4) => R
+ ): (
+ TypedColumn[T, A1],
+ TypedColumn[T, A2],
+ TypedColumn[T, A3],
+ TypedColumn[T, A4]
+ ) => TypedColumn[T, R] = functions.udf(f)
+
+ /**
+ * Takes a function from (A1, A2, A3, A4, A5) => R and converts it to a UDF for
+ * (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R].
+ */
+ def makeUDF[
+ A1: TypedEncoder,
+ A2: TypedEncoder,
+ A3: TypedEncoder,
+ A4: TypedEncoder,
+ A5: TypedEncoder,
+ R: TypedEncoder
+ ](f: (A1, A2, A3, A4, A5) => R
+ ): (
+ TypedColumn[T, A1],
+ TypedColumn[T, A2],
+ TypedColumn[T, A3],
+ TypedColumn[T, A4],
+ TypedColumn[T, A5]
+ ) => TypedColumn[T, R] = functions.udf(f)
+
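
A usage sketch for `makeUDF`, assuming a hypothetical `people: TypedDataset[Person]` with a `name: String` field:

    // lift a plain Scala function into a column-level transformation
    val initial: TypedColumn[Person, String] => TypedColumn[Person, String] =
      people.makeUDF((name: String) => name.take(1).toUpperCase)

    val initials: TypedDataset[String] = people.select(initial(people('name)))
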
+ /**
+ * Type-safe projection from type T to Tuple1[A]
+ * {{{
+ * d.select( d('a), d('a)+d('b), ... )
+ * }}}
+ */
def select[A](
- ca: TypedColumn[T, A]
- ): TypedDataset[A] = {
+ ca: TypedColumn[T, A]
+ ): TypedDataset[A] = {
implicit val ea = ca.uencoder
val tuple1: TypedDataset[Tuple1[A]] = selectMany(ca)
@@ -753,10 +982,8 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
TypedEncoder[A].catalystRepr match {
case StructType(_) =>
// if column is struct, we use all its fields
- val df = tuple1
- .dataset
- .selectExpr("_1.*")
- .as[A](TypedExpressionEncoder[A])
+ val df =
+ tuple1.dataset.selectExpr("_1.*").as[A](TypedExpressionEncoder[A])
TypedDataset.create(df)
case other =>
@@ -765,217 +992,288 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
}
}
- /** Type-safe projection from type T to Tuple2[A,B]
- * {{{
- * d.select( d('a), d('a)+d('b), ... )
- * }}}
- */
+ /**
+ * Type-safe projection from type T to Tuple2[A,B]
+ * {{{
+ * d.select( d('a), d('a)+d('b), ... )
+ * }}}
+ */
def select[A, B](
- ca: TypedColumn[T, A],
- cb: TypedColumn[T, B]
- ): TypedDataset[(A, B)] = {
+ ca: TypedColumn[T, A],
+ cb: TypedColumn[T, B]
+ ): TypedDataset[(A, B)] = {
implicit val (ea, eb) = (ca.uencoder, cb.uencoder)
selectMany(ca, cb)
}
- /** Type-safe projection from type T to Tuple3[A,B,...]
- * {{{
- * d.select( d('a), d('a)+d('b), ... )
- * }}}
- */
+ /**
+ * Type-safe projection from type T to Tuple3[A,B,...]
+ * {{{
+ * d.select( d('a), d('a)+d('b), ... )
+ * }}}
+ */
def select[A, B, C](
- ca: TypedColumn[T, A],
- cb: TypedColumn[T, B],
- cc: TypedColumn[T, C]
- ): TypedDataset[(A, B, C)] = {
+ ca: TypedColumn[T, A],
+ cb: TypedColumn[T, B],
+ cc: TypedColumn[T, C]
+ ): TypedDataset[(A, B, C)] = {
implicit val (ea, eb, ec) = (ca.uencoder, cb.uencoder, cc.uencoder)
selectMany(ca, cb, cc)
}
- /** Type-safe projection from type T to Tuple4[A,B,...]
- * {{{
- * d.select( d('a), d('a)+d('b), ... )
- * }}}
- */
+ /**
+ * Type-safe projection from type T to Tuple4[A,B,...]
+ * {{{
+ * d.select( d('a), d('a)+d('b), ... )
+ * }}}
+ */
def select[A, B, C, D](
- ca: TypedColumn[T, A],
- cb: TypedColumn[T, B],
- cc: TypedColumn[T, C],
- cd: TypedColumn[T, D]
- ): TypedDataset[(A, B, C, D)] = {
- implicit val (ea, eb, ec, ed) = (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder)
+ ca: TypedColumn[T, A],
+ cb: TypedColumn[T, B],
+ cc: TypedColumn[T, C],
+ cd: TypedColumn[T, D]
+ ): TypedDataset[(A, B, C, D)] = {
+ implicit val (ea, eb, ec, ed) =
+ (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder)
selectMany(ca, cb, cc, cd)
}
- /** Type-safe projection from type T to Tuple5[A,B,...]
- * {{{
- * d.select( d('a), d('a)+d('b), ... )
- * }}}
- */
+ /**
+ * Type-safe projection from type T to Tuple5[A,B,...]
+ * {{{
+ * d.select( d('a), d('a)+d('b), ... )
+ * }}}
+ */
def select[A, B, C, D, E](
- ca: TypedColumn[T, A],
- cb: TypedColumn[T, B],
- cc: TypedColumn[T, C],
- cd: TypedColumn[T, D],
- ce: TypedColumn[T, E]
- ): TypedDataset[(A, B, C, D, E)] = {
+ ca: TypedColumn[T, A],
+ cb: TypedColumn[T, B],
+ cc: TypedColumn[T, C],
+ cd: TypedColumn[T, D],
+ ce: TypedColumn[T, E]
+ ): TypedDataset[(A, B, C, D, E)] = {
implicit val (ea, eb, ec, ed, ee) =
(ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder, ce.uencoder)
selectMany(ca, cb, cc, cd, ce)
}
- /** Type-safe projection from type T to Tuple6[A,B,...]
- * {{{
- * d.select( d('a), d('a)+d('b), ... )
- * }}}
- */
+ /**
+ * Type-safe projection from type T to Tuple6[A,B,...]
+ * {{{
+ * d.select( d('a), d('a)+d('b), ... )
+ * }}}
+ */
def select[A, B, C, D, E, F](
- ca: TypedColumn[T, A],
- cb: TypedColumn[T, B],
- cc: TypedColumn[T, C],
- cd: TypedColumn[T, D],
- ce: TypedColumn[T, E],
- cf: TypedColumn[T, F]
- ): TypedDataset[(A, B, C, D, E, F)] = {
+ ca: TypedColumn[T, A],
+ cb: TypedColumn[T, B],
+ cc: TypedColumn[T, C],
+ cd: TypedColumn[T, D],
+ ce: TypedColumn[T, E],
+ cf: TypedColumn[T, F]
+ ): TypedDataset[(A, B, C, D, E, F)] = {
implicit val (ea, eb, ec, ed, ee, ef) =
- (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder, ce.uencoder, cf.uencoder)
+ (
+ ca.uencoder,
+ cb.uencoder,
+ cc.uencoder,
+ cd.uencoder,
+ ce.uencoder,
+ cf.uencoder
+ )
selectMany(ca, cb, cc, cd, ce, cf)
}
- /** Type-safe projection from type T to Tuple7[A,B,...]
- * {{{
- * d.select( d('a), d('a)+d('b), ... )
- * }}}
- */
+ /**
+ * Type-safe projection from type T to Tuple7[A,B,...]
+ * {{{
+ * d.select( d('a), d('a)+d('b), ... )
+ * }}}
+ */
def select[A, B, C, D, E, F, G](
- ca: TypedColumn[T, A],
- cb: TypedColumn[T, B],
- cc: TypedColumn[T, C],
- cd: TypedColumn[T, D],
- ce: TypedColumn[T, E],
- cf: TypedColumn[T, F],
- cg: TypedColumn[T, G]
- ): TypedDataset[(A, B, C, D, E, F, G)] = {
+ ca: TypedColumn[T, A],
+ cb: TypedColumn[T, B],
+ cc: TypedColumn[T, C],
+ cd: TypedColumn[T, D],
+ ce: TypedColumn[T, E],
+ cf: TypedColumn[T, F],
+ cg: TypedColumn[T, G]
+ ): TypedDataset[(A, B, C, D, E, F, G)] = {
implicit val (ea, eb, ec, ed, ee, ef, eg) =
- (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder, ce.uencoder, cf.uencoder, cg.uencoder)
+ (
+ ca.uencoder,
+ cb.uencoder,
+ cc.uencoder,
+ cd.uencoder,
+ ce.uencoder,
+ cf.uencoder,
+ cg.uencoder
+ )
selectMany(ca, cb, cc, cd, ce, cf, cg)
}
- /** Type-safe projection from type T to Tuple8[A,B,...]
- * {{{
- * d.select( d('a), d('a)+d('b), ... )
- * }}}
- */
+ /**
+ * Type-safe projection from type T to Tuple8[A,B,...]
+ * {{{
+ * d.select( d('a), d('a)+d('b), ... )
+ * }}}
+ */
def select[A, B, C, D, E, F, G, H](
- ca: TypedColumn[T, A],
- cb: TypedColumn[T, B],
- cc: TypedColumn[T, C],
- cd: TypedColumn[T, D],
- ce: TypedColumn[T, E],
- cf: TypedColumn[T, F],
- cg: TypedColumn[T, G],
- ch: TypedColumn[T, H]
- ): TypedDataset[(A, B, C, D, E, F, G, H)] = {
+ ca: TypedColumn[T, A],
+ cb: TypedColumn[T, B],
+ cc: TypedColumn[T, C],
+ cd: TypedColumn[T, D],
+ ce: TypedColumn[T, E],
+ cf: TypedColumn[T, F],
+ cg: TypedColumn[T, G],
+ ch: TypedColumn[T, H]
+ ): TypedDataset[(A, B, C, D, E, F, G, H)] = {
implicit val (ea, eb, ec, ed, ee, ef, eg, eh) =
- (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder, ce.uencoder, cf.uencoder, cg.uencoder, ch.uencoder)
+ (
+ ca.uencoder,
+ cb.uencoder,
+ cc.uencoder,
+ cd.uencoder,
+ ce.uencoder,
+ cf.uencoder,
+ cg.uencoder,
+ ch.uencoder
+ )
selectMany(ca, cb, cc, cd, ce, cf, cg, ch)
}
- /** Type-safe projection from type T to Tuple9[A,B,...]
- * {{{
- * d.select( d('a), d('a)+d('b), ... )
- * }}}
- */
+ /**
+ * Type-safe projection from type T to Tuple9[A,B,...]
+ * {{{
+ * d.select( d('a), d('a)+d('b), ... )
+ * }}}
+ */
def select[A, B, C, D, E, F, G, H, I](
- ca: TypedColumn[T, A],
- cb: TypedColumn[T, B],
- cc: TypedColumn[T, C],
- cd: TypedColumn[T, D],
- ce: TypedColumn[T, E],
- cf: TypedColumn[T, F],
- cg: TypedColumn[T, G],
- ch: TypedColumn[T, H],
- ci: TypedColumn[T, I]
- ): TypedDataset[(A, B, C, D, E, F, G, H, I)] = {
+ ca: TypedColumn[T, A],
+ cb: TypedColumn[T, B],
+ cc: TypedColumn[T, C],
+ cd: TypedColumn[T, D],
+ ce: TypedColumn[T, E],
+ cf: TypedColumn[T, F],
+ cg: TypedColumn[T, G],
+ ch: TypedColumn[T, H],
+ ci: TypedColumn[T, I]
+ ): TypedDataset[(A, B, C, D, E, F, G, H, I)] = {
implicit val (ea, eb, ec, ed, ee, ef, eg, eh, ei) =
- (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder, ce.uencoder, cf.uencoder, cg.uencoder, ch.uencoder, ci.uencoder)
+ (
+ ca.uencoder,
+ cb.uencoder,
+ cc.uencoder,
+ cd.uencoder,
+ ce.uencoder,
+ cf.uencoder,
+ cg.uencoder,
+ ch.uencoder,
+ ci.uencoder
+ )
selectMany(ca, cb, cc, cd, ce, cf, cg, ch, ci)
}
- /** Type-safe projection from type T to Tuple10[A,B,...]
- * {{{
- * d.select( d('a), d('a)+d('b), ... )
- * }}}
- */
+ /**
+ * Type-safe projection from type T to Tuple10[A,B,...]
+ * {{{
+ * d.select( d('a), d('a)+d('b), ... )
+ * }}}
+ */
def select[A, B, C, D, E, F, G, H, I, J](
- ca: TypedColumn[T, A],
- cb: TypedColumn[T, B],
- cc: TypedColumn[T, C],
- cd: TypedColumn[T, D],
- ce: TypedColumn[T, E],
- cf: TypedColumn[T, F],
- cg: TypedColumn[T, G],
- ch: TypedColumn[T, H],
- ci: TypedColumn[T, I],
- cj: TypedColumn[T, J]
- ): TypedDataset[(A, B, C, D, E, F, G, H, I, J)] = {
+ ca: TypedColumn[T, A],
+ cb: TypedColumn[T, B],
+ cc: TypedColumn[T, C],
+ cd: TypedColumn[T, D],
+ ce: TypedColumn[T, E],
+ cf: TypedColumn[T, F],
+ cg: TypedColumn[T, G],
+ ch: TypedColumn[T, H],
+ ci: TypedColumn[T, I],
+ cj: TypedColumn[T, J]
+ ): TypedDataset[(A, B, C, D, E, F, G, H, I, J)] = {
implicit val (ea, eb, ec, ed, ee, ef, eg, eh, ei, ej) =
- (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder, ce.uencoder, cf.uencoder, cg.uencoder, ch.uencoder, ci.uencoder, cj.uencoder)
+ (
+ ca.uencoder,
+ cb.uencoder,
+ cc.uencoder,
+ cd.uencoder,
+ ce.uencoder,
+ cf.uencoder,
+ cg.uencoder,
+ ch.uencoder,
+ ci.uencoder,
+ cj.uencoder
+ )
selectMany(ca, cb, cc, cd, ce, cf, cg, ch, ci, cj)
}
object selectMany extends ProductArgs {
- def applyProduct[U <: HList, Out0 <: HList, Out](columns: U)
- (implicit
+
+ def applyProduct[U <: HList, Out0 <: HList, Out](
+ columns: U
+ )(implicit
i0: ColumnTypes.Aux[T, U, Out0],
i1: ToTraversable.Aux[U, List, UntypedExpression[T]],
i2: Tupler.Aux[Out0, Out],
i3: TypedEncoder[Out]
): TypedDataset[Out] = {
- val base = dataset.toDF()
- .select(columns.toList[UntypedExpression[T]].map(c => new Column(c.expr)):_*)
- val selected = base.as[Out](TypedExpressionEncoder[Out])
+ val base = dataset
+ .toDF()
+ .select(
+ columns.toList[UntypedExpression[T]].map(c => new Column(c.expr)): _*
+ )
+ val selected = base.as[Out](TypedExpressionEncoder[Out])
- TypedDataset.create[Out](selected)
- }
+ TypedDataset.create[Out](selected)
+ }
}
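// Illustrative sketch, not part of the diff: the overloaded `select`s above all delegate to
// `selectMany` and project to a tuple. `Person` and `people` are hypothetical names;
// assumes an implicit SparkSession is in scope.
case class Person(name: String, age: Int)
val people: TypedDataset[Person] = TypedDataset.create(Seq(Person("a", 1), Person("b", 2)))
// A two-column projection yields a TypedDataset of the corresponding tuple type.
val pairs: TypedDataset[(String, Int)] = people.select(people('name), people('age))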
/** Sort each partition in the dataset using the columns selected. */
- def sortWithinPartitions[A: CatalystOrdered](ca: SortedTypedColumn[T, A]): TypedDataset[T] =
+ def sortWithinPartitions[A: CatalystOrdered](
+ ca: SortedTypedColumn[T, A]
+ ): TypedDataset[T] =
sortWithinPartitionsMany(ca)
/** Sort each partition in the dataset using the columns selected. */
def sortWithinPartitions[A: CatalystOrdered, B: CatalystOrdered](
- ca: SortedTypedColumn[T, A],
- cb: SortedTypedColumn[T, B]
- ): TypedDataset[T] = sortWithinPartitionsMany(ca, cb)
+ ca: SortedTypedColumn[T, A],
+ cb: SortedTypedColumn[T, B]
+ ): TypedDataset[T] = sortWithinPartitionsMany(ca, cb)
/** Sort each partition in the dataset using the columns selected. */
- def sortWithinPartitions[A: CatalystOrdered, B: CatalystOrdered, C: CatalystOrdered](
- ca: SortedTypedColumn[T, A],
- cb: SortedTypedColumn[T, B],
- cc: SortedTypedColumn[T, C]
- ): TypedDataset[T] = sortWithinPartitionsMany(ca, cb, cc)
-
- /** Sort each partition in the dataset by the given column expressions
- * Default sort order is ascending.
- * {{{
- * d.sortWithinPartitionsMany(d('a), d('b).desc, d('c).asc)
- * }}}
- */
+ def sortWithinPartitions[
+ A: CatalystOrdered,
+ B: CatalystOrdered,
+ C: CatalystOrdered
+ ](ca: SortedTypedColumn[T, A],
+ cb: SortedTypedColumn[T, B],
+ cc: SortedTypedColumn[T, C]
+ ): TypedDataset[T] = sortWithinPartitionsMany(ca, cb, cc)
+
+ /**
+ * Sort each partition in the dataset by the given column expressions
+ * Default sort order is ascending.
+ * {{{
+ * d.sortWithinPartitionsMany(d('a), d('b).desc, d('c).asc)
+ * }}}
+ */
object sortWithinPartitionsMany extends ProductArgs {
- def applyProduct[U <: HList, O <: HList](columns: U)
- (implicit
+
+ def applyProduct[U <: HList, O <: HList](
+ columns: U
+ )(implicit
i0: Mapper.Aux[SortedTypedColumn.defaultAscendingPoly.type, U, O],
i1: ToTraversable.Aux[O, List, SortedTypedColumn[T, _]]
): TypedDataset[T] = {
- val sorted = dataset.toDF()
- .sortWithinPartitions(i0(columns).toList[SortedTypedColumn[T, _]].map(_.untyped):_*)
+ val sorted = dataset
+ .toDF()
+ .sortWithinPartitions(
+ i0(columns).toList[SortedTypedColumn[T, _]].map(_.untyped): _*
+ )
.as[T](TypedExpressionEncoder[T])
TypedDataset.create[T](sorted)
@@ -983,273 +1281,316 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
}
/** Orders the TypedDataset using the column selected. */
- def orderBy[A: CatalystOrdered](ca: SortedTypedColumn[T, A]): TypedDataset[T] =
+ def orderBy[A: CatalystOrdered](
+ ca: SortedTypedColumn[T, A]
+ ): TypedDataset[T] =
orderByMany(ca)
/** Orders the TypedDataset using the columns selected. */
def orderBy[A: CatalystOrdered, B: CatalystOrdered](
- ca: SortedTypedColumn[T, A],
- cb: SortedTypedColumn[T, B]
- ): TypedDataset[T] = orderByMany(ca, cb)
-
- /** Orders the TypedDataset using the columns selected. */
- def orderBy[A: CatalystOrdered, B: CatalystOrdered, C: CatalystOrdered](
- ca: SortedTypedColumn[T, A],
- cb: SortedTypedColumn[T, B],
- cc: SortedTypedColumn[T, C]
- ): TypedDataset[T] = orderByMany(ca, cb, cc)
-
- /** Sort the dataset by any number of column expressions.
- * Default sort order is ascending.
- * {{{
- * d.orderByMany(d('a), d('b).desc, d('c).asc)
- * }}}
- */
+ ca: SortedTypedColumn[T, A],
+ cb: SortedTypedColumn[T, B]
+ ): TypedDataset[T] = orderByMany(ca, cb)
+
+ /** Orders the TypedDataset using the columns selected. */
+ def orderBy[A: CatalystOrdered, B: CatalystOrdered, C: CatalystOrdered](
+ ca: SortedTypedColumn[T, A],
+ cb: SortedTypedColumn[T, B],
+ cc: SortedTypedColumn[T, C]
+ ): TypedDataset[T] = orderByMany(ca, cb, cc)
+
+ /**
+ * Sort the dataset by any number of column expressions.
+ * Default sort order is ascending.
+ * {{{
+ * d.orderByMany(d('a), d('b).desc, d('c).asc)
+ * }}}
+ */
object orderByMany extends ProductArgs {
- def applyProduct[U <: HList, O <: HList](columns: U)
- (implicit
+
+ def applyProduct[U <: HList, O <: HList](
+ columns: U
+ )(implicit
i0: Mapper.Aux[SortedTypedColumn.defaultAscendingPoly.type, U, O],
i1: ToTraversable.Aux[O, List, SortedTypedColumn[T, _]]
): TypedDataset[T] = {
- val sorted = dataset.toDF()
- .orderBy(i0(columns).toList[SortedTypedColumn[T, _]].map(_.untyped):_*)
+ val sorted = dataset
+ .toDF()
+ .orderBy(i0(columns).toList[SortedTypedColumn[T, _]].map(_.untyped): _*)
.as[T](TypedExpressionEncoder[T])
TypedDataset.create[T](sorted)
}
}
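// Illustrative sketch, not part of the diff: ordering through the ProductArgs-based
// `orderByMany`, mirroring the scaladoc example. Reuses the hypothetical `people`
// dataset from the previous sketch; default sort order is ascending.
val ordered: TypedDataset[Person] =
  people.orderByMany(people('age).desc, people('name).asc)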
- /** Returns a new Dataset as a tuple with the specified
- * column dropped.
- * Does not allow for dropping from a single column TypedDataset
- *
- * {{{
- * val d: TypedDataset[Foo(a: String, b: Int...)] = ???
- * val result = TypedDataset[(Int, ...)] = d.drop('a)
- * }}}
- * @param column column to drop specified as a Symbol
- * @param i0 LabelledGeneric derived for T
- * @param i1 Remover derived for TRep and column
- * @param i2 values of T with column removed
- * @param i3 tupler of values
- * @param i4 evidence of encoder of the tupled values
- * @tparam Out Tupled return type
- * @tparam TRep shapeless' record representation of T
- * @tparam Removed record of T with column removed
- * @tparam ValuesFromRemoved values of T with column removed as an HList
- * @tparam V value type of column in T
- * @return
- */
- def dropTupled[Out, TRep <: HList, Removed <: HList, ValuesFromRemoved <: HList, V]
- (column: Witness.Lt[Symbol])
- (implicit
+ /**
+ * Returns a new Dataset as a tuple with the specified
+ * column dropped.
+ * Does not allow for dropping from a single column TypedDataset
+ *
+ * {{{
+ * val d: TypedDataset[Foo(a: String, b: Int...)] = ???
+ * val result: TypedDataset[(Int, ...)] = d.dropTupled('a)
+ * }}}
+ * @param column column to drop specified as a Symbol
+ * @param i0 LabelledGeneric derived for T
+ * @param i1 Remover derived for TRep and column
+ * @param i2 values of T with column removed
+ * @param i3 tupler of values
+ * @param i4 evidence of encoder of the tupled values
+ * @tparam Out Tupled return type
+ * @tparam TRep shapeless' record representation of T
+ * @tparam Removed record of T with column removed
+ * @tparam ValuesFromRemoved values of T with column removed as an HList
+ * @tparam V value type of column in T
+ * @return
+ */
+ def dropTupled[
+ Out,
+ TRep <: HList,
+ Removed <: HList,
+ ValuesFromRemoved <: HList,
+ V
+ ](column: Witness.Lt[Symbol]
+ )(implicit
i0: LabelledGeneric.Aux[T, TRep],
i1: Remover.Aux[TRep, column.T, (V, Removed)],
i2: Values.Aux[Removed, ValuesFromRemoved],
i3: Tupler.Aux[ValuesFromRemoved, Out],
i4: TypedEncoder[Out]
): TypedDataset[Out] = {
- val dropped = dataset
- .toDF()
- .drop(column.value.name)
- .as[Out](TypedExpressionEncoder[Out])
+ val dropped = dataset
+ .toDF()
+ .drop(column.value.name)
+ .as[Out](TypedExpressionEncoder[Out])
- TypedDataset.create[Out](dropped)
- }
+ TypedDataset.create[Out](dropped)
+ }
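// Illustrative sketch, not part of the diff: dropping a single column by Symbol with
// `dropTupled`; the remaining columns come back as a tuple. `Foo` and `foos` are
// hypothetical names; assumes an implicit SparkSession is in scope.
case class Foo(a: String, b: Int, c: Boolean)
val foos: TypedDataset[Foo] = TypedDataset.create(Seq(Foo("x", 1, true)))
val rest: TypedDataset[(Int, Boolean)] = foos.dropTupled('a)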
/**
- * Drops columns as necessary to return `U`
- *
- * @example
- * {{{
- * case class X(i: Int, j: Int, k: Boolean)
- * case class Y(i: Int, k: Boolean)
- * val f: TypedDataset[X] = ???
- * val fNew: TypedDataset[Y] = f.drop[Y]
- * }}}
- *
- * @tparam U the output type
- *
- * @see [[frameless.TypedDataset#project]]
- */
- def drop[U](implicit projector: SmartProject[T,U]): TypedDataset[U] = project[U]
-
- /** Prepends a new column to the Dataset.
- *
- * {{{
- * case class X(i: Int, j: Int)
- * val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil)
- * val fNew: TypedDataset[(Int,Int,Boolean)] = f.withColumnTupled(f('j) === 10)
- * }}}
- */
- def withColumnTupled[A: TypedEncoder, H <: HList, FH <: HList, Out]
- (ca: TypedColumn[T, A])
- (implicit
+ * Drops columns as necessary to return `U`
+ *
+ * @example
+ * {{{
+ * case class X(i: Int, j: Int, k: Boolean)
+ * case class Y(i: Int, k: Boolean)
+ * val f: TypedDataset[X] = ???
+ * val fNew: TypedDataset[Y] = f.drop[Y]
+ * }}}
+ *
+ * @tparam U the output type
+ *
+ * @see [[frameless.TypedDataset#project]]
+ */
+ def drop[U](
+ implicit
+ projector: SmartProject[T, U]
+ ): TypedDataset[U] = project[U]
+
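// Illustrative sketch, not part of the diff: projecting away columns by target type with
// `drop[U]`, which is an alias for `project[U]`, following the scaladoc example above.
case class X(i: Int, j: Int, k: Boolean)
case class Y(i: Int, k: Boolean)
val xs: TypedDataset[X] = TypedDataset.create(Seq(X(1, 2, true)))
val ys: TypedDataset[Y] = xs.drop[Y]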
+ /**
+ * Prepends a new column to the Dataset.
+ *
+ * {{{
+ * case class X(i: Int, j: Int)
+ * val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil)
+ * val fNew: TypedDataset[(Int,Int,Boolean)] = f.withColumnTupled(f('j) === 10)
+ * }}}
+ */
+ def withColumnTupled[A: TypedEncoder, H <: HList, FH <: HList, Out](
+ ca: TypedColumn[T, A]
+ )(implicit
i0: Generic.Aux[T, H],
i1: Prepend.Aux[H, A :: HNil, FH],
i2: Tupler.Aux[FH, Out],
i3: TypedEncoder[Out]
): TypedDataset[Out] = {
- // Giving a random name to the new column (the proper name will be given by the Tuple-based encoder)
- val selected = dataset.toDF().withColumn("I1X3T9CU1OP0128JYIO76TYZZA3AXHQ18RMI", ca.untyped)
- .as[Out](TypedExpressionEncoder[Out])
+ // Giving a random name to the new column (the proper name will be given by the Tuple-based encoder)
+ val selected = dataset
+ .toDF()
+ .withColumn("I1X3T9CU1OP0128JYIO76TYZZA3AXHQ18RMI", ca.untyped)
+ .as[Out](TypedExpressionEncoder[Out])
- TypedDataset.create[Out](selected)
+ TypedDataset.create[Out](selected)
}
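// Illustrative sketch, not part of the diff: appending a derived column with
// `withColumnTupled`; the result is re-encoded as a tuple, as in the scaladoc above.
// `P` and `ps` are hypothetical names.
case class P(i: Int, j: Int)
val ps: TypedDataset[P] = TypedDataset.create(Seq(P(1, 1), P(1, 10)))
val flagged: TypedDataset[(Int, Int, Boolean)] = ps.withColumnTupled(ps('j) === 10)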
- /** Returns a new [[frameless.TypedDataset]] with the specified column updated with a new value
- * {{{
- * case class X(i: Int, j: Int)
- * val f: TypedDataset[X] = TypedDataset.create(X(1,10) :: Nil)
- * val fNew: TypedDataset[X] = f.withColumn('j, f('i)) // results in X(1, 1) :: Nil
- * }}}
- * @param column column given as a symbol to replace
- * @param replacement column to replace the value with
- * @param i0 Evidence that a column with the correct type and name exists
- */
+ /**
+ * Returns a new [[frameless.TypedDataset]] with the specified column updated with a new value
+ * {{{
+ * case class X(i: Int, j: Int)
+ * val f: TypedDataset[X] = TypedDataset.create(X(1,10) :: Nil)
+ * val fNew: TypedDataset[X] = f.withColumnReplaced('j, f('i)) // results in X(1, 1) :: Nil
+ * }}}
+ * @param column column given as a symbol to replace
+ * @param replacement column to replace the value with
+ * @param i0 Evidence that a column with the correct type and name exists
+ */
def withColumnReplaced[A](
- column: Witness.Lt[Symbol],
- replacement: TypedColumn[T, A]
- )(implicit
- i0: TypedColumn.Exists[T, column.T, A]
- ): TypedDataset[T] = {
- val updated = dataset.toDF().withColumn(column.value.name, replacement.untyped)
+ column: Witness.Lt[Symbol],
+ replacement: TypedColumn[T, A]
+ )(implicit
+ i0: TypedColumn.Exists[T, column.T, A]
+ ): TypedDataset[T] = {
+ val updated = dataset
+ .toDF()
+ .withColumn(column.value.name, replacement.untyped)
.as[T](TypedExpressionEncoder[T])
TypedDataset.create[T](updated)
}
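// Illustrative sketch, not part of the diff: replacing an existing column in place with
// `withColumnReplaced`; the row type T stays the same, only the values change. Reuses the
// hypothetical `ps` dataset from the previous sketch.
val replaced: TypedDataset[P] = ps.withColumnReplaced('j, ps('i))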
- /** Adds a column to a Dataset so long as the specified output type, `U`, has
- * an extra column from `T` that has type `A`.
- *
- * @example
- * {{{
- * case class X(i: Int, j: Int)
- * case class Y(i: Int, j: Int, k: Boolean)
- * val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil)
- * val fNew: TypedDataset[Y] = f.withColumn[Y](f('j) === 10)
- * }}}
- * @param ca The typed column to add
- * @param i0 TypeEncoder for output type U
- * @param i1 TypeEncoder for added column type A
- * @param i2 the LabelledGeneric derived for T
- * @param i3 the LabelledGeneric derived for U
- * @param i4 proof no fields have been removed
- * @param i5 diff from T to U
- * @param i6 keys from newFields
- * @param i7 the one and only new key
- * @param i8 the one and only new field enforcing the type of A exists
- * @param i9 the keys of U
- * @param iA allows for traversing the keys of U
- * @tparam U the output type
- * @tparam A The added column type
- * @tparam TRep shapeless' record representation of T
- * @tparam URep shapeless' record representation of U
- * @tparam UKeys the keys of U as an HList
- * @tparam NewFields the added fields to T to get U
- * @tparam NewKeys the keys of NewFields as an HList
- * @tparam NewKey the first, and only, key in NewKey
- *
- * @see [[frameless.TypedDataset.WithColumnApply#apply]]
- */
+ /**
+ * Adds a column to a Dataset so long as the specified output type, `U`, has
+ * an extra column from `T` that has type `A`.
+ *
+ * @example
+ * {{{
+ * case class X(i: Int, j: Int)
+ * case class Y(i: Int, j: Int, k: Boolean)
+ * val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil)
+ * val fNew: TypedDataset[Y] = f.withColumn[Y](f('j) === 10)
+ * }}}
+ * @param ca The typed column to add
+ * @param i0 TypeEncoder for output type U
+ * @param i1 TypeEncoder for added column type A
+ * @param i2 the LabelledGeneric derived for T
+ * @param i3 the LabelledGeneric derived for U
+ * @param i4 proof no fields have been removed
+ * @param i5 diff from T to U
+ * @param i6 keys from newFields
+ * @param i7 the one and only new key
+ * @param i8 the one and only new field enforcing the type of A exists
+ * @param i9 the keys of U
+ * @param iA allows for traversing the keys of U
+ * @tparam U the output type
+ * @tparam A The added column type
+ * @tparam TRep shapeless' record representation of T
+ * @tparam URep shapeless' record representation of U
+ * @tparam UKeys the keys of U as an HList
+ * @tparam NewFields the added fields to T to get U
+ * @tparam NewKeys the keys of NewFields as an HList
+ * @tparam NewKey the first, and only, key in NewKey
+ *
+ * @see [[frameless.TypedDataset.WithColumnApply#apply]]
+ */
def withColumn[U] = new WithColumnApply[U]
class WithColumnApply[U] {
- def apply[A, TRep <: HList, URep <: HList, UKeys <: HList, NewFields <: HList, NewKeys <: HList, NewKey <: Symbol]
- (ca: TypedColumn[T, A])
- (implicit
- i0: TypedEncoder[U],
- i1: TypedEncoder[A],
- i2: LabelledGeneric.Aux[T, TRep],
- i3: LabelledGeneric.Aux[U, URep],
- i4: Diff.Aux[TRep, URep, HNil],
- i5: Diff.Aux[URep, TRep, NewFields],
- i6: Keys.Aux[NewFields, NewKeys],
- i7: IsHCons.Aux[NewKeys, NewKey, HNil],
- i8: IsHCons.Aux[NewFields, FieldType[NewKey, A], HNil],
- i9: Keys.Aux[URep, UKeys],
- iA: ToTraversable.Aux[UKeys, Seq, Symbol]
- ): TypedDataset[U] = {
+
+ def apply[
+ A,
+ TRep <: HList,
+ URep <: HList,
+ UKeys <: HList,
+ NewFields <: HList,
+ NewKeys <: HList,
+ NewKey <: Symbol
+ ](ca: TypedColumn[T, A]
+ )(implicit
+ i0: TypedEncoder[U],
+ i1: TypedEncoder[A],
+ i2: LabelledGeneric.Aux[T, TRep],
+ i3: LabelledGeneric.Aux[U, URep],
+ i4: Diff.Aux[TRep, URep, HNil],
+ i5: Diff.Aux[URep, TRep, NewFields],
+ i6: Keys.Aux[NewFields, NewKeys],
+ i7: IsHCons.Aux[NewKeys, NewKey, HNil],
+ i8: IsHCons.Aux[NewFields, FieldType[NewKey, A], HNil],
+ i9: Keys.Aux[URep, UKeys],
+ iA: ToTraversable.Aux[UKeys, Seq, Symbol]
+ ): TypedDataset[U] = {
val newColumnName =
i7.head(i6()).name
- val dfWithNewColumn = dataset
- .toDF()
- .withColumn(newColumnName, ca.untyped)
+ val dfWithNewColumn = dataset.toDF().withColumn(newColumnName, ca.untyped)
val newColumns = i9.apply().to[Seq].map(_.name).map(dfWithNewColumn.col)
- val selected = dfWithNewColumn
- .select(newColumns: _*)
- .as[U](TypedExpressionEncoder[U])
+ val selected =
+ dfWithNewColumn.select(newColumns: _*).as[U](TypedExpressionEncoder[U])
TypedDataset.create[U](selected)
}
}
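// Illustrative sketch, not part of the diff: adding the single extra column needed to go
// from T to U, mirroring the scaladoc example for `withColumn[U]`. Hypothetical names.
case class X2(i: Int, j: Int)
case class Y2(i: Int, j: Int, k: Boolean)
val x2s: TypedDataset[X2] = TypedDataset.create(Seq(X2(1, 1), X2(1, 10)))
val y2s: TypedDataset[Y2] = x2s.withColumn[Y2](x2s('j) === 10)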
/**
- * Explodes a single column at a time. It only compiles if the type of column supports this operation.
- *
- * @example
- *
- * {{{
- * case class X(i: Int, j: Array[Int])
- * case class Y(i: Int, j: Int)
- *
- * val f: TypedDataset[X] = ???
- * val fNew: TypedDataset[Y] = f.explode('j).as[Y]
- * }}}
- * @param column the column we wish to explode
- */
- def explode[A, TRep <: HList, V[_], OutMod <: HList, OutModValues <: HList, Out]
- (column: Witness.Lt[Symbol])
- (implicit
- i0: TypedColumn.Exists[T, column.T, V[A]],
- i1: TypedEncoder[A],
- i2: CatalystExplodableCollection[V],
- i3: LabelledGeneric.Aux[T, TRep],
- i4: Modifier.Aux[TRep, column.T, V[A], A, OutMod],
- i5: Values.Aux[OutMod, OutModValues],
- i6: Tupler.Aux[OutModValues, Out],
- i7: TypedEncoder[Out]
- ): TypedDataset[Out] = {
- import org.apache.spark.sql.functions.{explode => sparkExplode}
+ * Explodes a single column at a time. It only compiles if the type of column supports this operation.
+ *
+ * @example
+ *
+ * {{{
+ * case class X(i: Int, j: Array[Int])
+ * case class Y(i: Int, j: Int)
+ *
+ * val f: TypedDataset[X] = ???
+ * val fNew: TypedDataset[Y] = f.explode('j).as[Y]
+ * }}}
+ * @param column the column we wish to explode
+ */
+ def explode[
+ A,
+ TRep <: HList,
+ V[_],
+ OutMod <: HList,
+ OutModValues <: HList,
+ Out
+ ](column: Witness.Lt[Symbol]
+ )(implicit
+ i0: TypedColumn.Exists[T, column.T, V[A]],
+ i1: TypedEncoder[A],
+ i2: CatalystExplodableCollection[V],
+ i3: LabelledGeneric.Aux[T, TRep],
+ i4: Modifier.Aux[TRep, column.T, V[A], A, OutMod],
+ i5: Values.Aux[OutMod, OutModValues],
+ i6: Tupler.Aux[OutModValues, Out],
+ i7: TypedEncoder[Out]
+ ): TypedDataset[Out] = {
+ import org.apache.spark.sql.functions.{ explode => sparkExplode }
val df = dataset.toDF()
val trans =
- df
- .withColumn(column.value.name, sparkExplode(df(column.value.name)))
+ df.withColumn(column.value.name, sparkExplode(df(column.value.name)))
.as[Out](TypedExpressionEncoder[Out])
TypedDataset.create[Out](trans)
}
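// Illustrative sketch, not part of the diff: exploding a collection column, following the
// scaladoc example; every element of the array becomes its own row. Hypothetical names.
case class Wrapped(i: Int, js: Array[Int])
val wrapped: TypedDataset[Wrapped] = TypedDataset.create(Seq(Wrapped(1, Array(10, 20))))
// Produces one (Int, Int) row per element of `js`: (1, 10) and (1, 20).
val exploded: TypedDataset[(Int, Int)] = wrapped.explode('js)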
/**
- * Explodes a single column at a time. It only compiles if the type of column supports this operation.
- *
- * @example
- *
- * {{{
- * case class X(i: Int, j: Map[Int, Int])
- * case class Y(i: Int, j: (Int, Int))
- *
- * val f: TypedDataset[X] = ???
- * val fNew: TypedDataset[Y] = f.explodeMap('j).as[Y]
- * }}}
- * @param column the column we wish to explode
- */
- def explodeMap[A, B, V[_, _], TRep <: HList, OutMod <: HList, OutModValues <: HList, Out]
- (column: Witness.Lt[Symbol])
- (implicit
- i0: TypedColumn.Exists[T, column.T, V[A, B]],
- i1: TypedEncoder[A],
- i2: TypedEncoder[B],
- i3: LabelledGeneric.Aux[T, TRep],
- i4: Modifier.Aux[TRep, column.T, V[A,B], (A, B), OutMod],
- i5: Values.Aux[OutMod, OutModValues],
- i6: Tupler.Aux[OutModValues, Out],
- i7: TypedEncoder[Out]
- ): TypedDataset[Out] = {
- import org.apache.spark.sql.functions.{explode => sparkExplode, struct => sparkStruct, col => sparkCol}
+ * Explodes a single column at a time. It only compiles if the type of column supports this operation.
+ *
+ * @example
+ *
+ * {{{
+ * case class X(i: Int, j: Map[Int, Int])
+ * case class Y(i: Int, j: (Int, Int))
+ *
+ * val f: TypedDataset[X] = ???
+ * val fNew: TypedDataset[Y] = f.explodeMap('j).as[Y]
+ * }}}
+ * @param column the column we wish to explode
+ */
+ def explodeMap[
+ A,
+ B,
+ V[_, _],
+ TRep <: HList,
+ OutMod <: HList,
+ OutModValues <: HList,
+ Out
+ ](column: Witness.Lt[Symbol]
+ )(implicit
+ i0: TypedColumn.Exists[T, column.T, V[A, B]],
+ i1: TypedEncoder[A],
+ i2: TypedEncoder[B],
+ i3: LabelledGeneric.Aux[T, TRep],
+ i4: Modifier.Aux[TRep, column.T, V[A, B], (A, B), OutMod],
+ i5: Values.Aux[OutMod, OutModValues],
+ i6: Tupler.Aux[OutModValues, Out],
+ i7: TypedEncoder[Out]
+ ): TypedDataset[Out] = {
+ import org.apache.spark.sql.functions.{
+ explode => sparkExplode,
+ struct => sparkStruct,
+ col => sparkCol
+ }
val df = dataset.toDF()
// select all columns, all original columns and [key, value] columns appeared after the map explode
@@ -1271,7 +1612,10 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
exploded
// map explode explodes it into [key, value] columns
// the only way to put it into a column is to create a struct
- .withColumn(columnRenamed, sparkStruct(exploded("key"), exploded("value")))
+ .withColumn(
+ columnRenamed,
+ sparkStruct(exploded("key"), exploded("value"))
+ )
// selecting only original columns, we don't need [key, value] columns left in the DataFrame after the map explode
.select(columns: _*)
// rename columns back and form the result
@@ -1281,72 +1625,81 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
}
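// Illustrative sketch, not part of the diff: exploding a Map column as in the scaladoc;
// each key/value pair becomes a row whose exploded column is a (key, value) tuple.
case class M(i: Int, kv: Map[Int, Int])
val ms: TypedDataset[M] = TypedDataset.create(Seq(M(1, Map(1 -> 10, 2 -> 20))))
val pairsOut: TypedDataset[(Int, (Int, Int))] = ms.explodeMap('kv)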
/**
- * Flattens a column of type Option[A]. Compiles only if the selected column is of type Option[A].
- *
- *
- * @example
- *
- * {{{
- * case class X(i: Int, j: Option[Int])
- * case class Y(i: Int, j: Int)
- *
- * val f: TypedDataset[X] = ???
- * val fNew: TypedDataset[Y] = f.flattenOption('j).as[Y]
- * }}}
- *
- * @param column the column we wish to flatten
- */
- def flattenOption[A, TRep <: HList, V[_], OutMod <: HList, OutModValues <: HList, Out]
- (column: Witness.Lt[Symbol])
- (implicit
- i0: TypedColumn.Exists[T, column.T, V[A]],
- i1: TypedEncoder[A],
- i2: V[A] =:= Option[A],
- i3: LabelledGeneric.Aux[T, TRep],
- i4: Modifier.Aux[TRep, column.T, V[A], A, OutMod],
- i5: Values.Aux[OutMod, OutModValues],
- i6: Tupler.Aux[OutModValues, Out],
- i7: TypedEncoder[Out]
- ): TypedDataset[Out] = {
+ * Flattens a column of type Option[A]. Compiles only if the selected column is of type Option[A].
+ *
+ * @example
+ *
+ * {{{
+ * case class X(i: Int, j: Option[Int])
+ * case class Y(i: Int, j: Int)
+ *
+ * val f: TypedDataset[X] = ???
+ * val fNew: TypedDataset[Y] = f.flattenOption('j).as[Y]
+ * }}}
+ *
+ * @param column the column we wish to flatten
+ */
+ def flattenOption[
+ A,
+ TRep <: HList,
+ V[_],
+ OutMod <: HList,
+ OutModValues <: HList,
+ Out
+ ](column: Witness.Lt[Symbol]
+ )(implicit
+ i0: TypedColumn.Exists[T, column.T, V[A]],
+ i1: TypedEncoder[A],
+ i2: V[A] =:= Option[A],
+ i3: LabelledGeneric.Aux[T, TRep],
+ i4: Modifier.Aux[TRep, column.T, V[A], A, OutMod],
+ i5: Values.Aux[OutMod, OutModValues],
+ i6: Tupler.Aux[OutModValues, Out],
+ i7: TypedEncoder[Out]
+ ): TypedDataset[Out] = {
val df = dataset.toDF()
- val trans = df.filter(df(column.value.name).isNotNull).
- as[Out](TypedExpressionEncoder[Out])
+ val trans = df
+ .filter(df(column.value.name).isNotNull)
+ .as[Out](TypedExpressionEncoder[Out])
TypedDataset.create[Out](trans)
}
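// Illustrative sketch, not part of the diff: flattening an Option column keeps only the
// rows where the value is defined, as described in the scaladoc above. Hypothetical names.
case class OptX(i: Int, j: Option[Int])
val optXs: TypedDataset[OptX] = TypedDataset.create(Seq(OptX(1, Some(2)), OptX(3, None)))
// Only OptX(1, 2) survives; the Option column becomes a plain Int.
val flat: TypedDataset[(Int, Int)] = optXs.flattenOption('j)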
}
object TypedDataset {
- def create[A](data: Seq[A])
- (implicit
+
+ def create[A](
+ data: Seq[A]
+ )(implicit
encoder: TypedEncoder[A],
sqlContext: SparkSession
): TypedDataset[A] = {
- val dataset = sqlContext.createDataset(data)(TypedExpressionEncoder[A])
+ val dataset = sqlContext.createDataset(data)(TypedExpressionEncoder[A])
- TypedDataset.create[A](dataset)
- }
+ TypedDataset.create[A](dataset)
+ }
- def create[A](data: RDD[A])
- (implicit
+ def create[A](
+ data: RDD[A]
+ )(implicit
encoder: TypedEncoder[A],
sqlContext: SparkSession
): TypedDataset[A] = {
- val dataset = sqlContext.createDataset(data)(TypedExpressionEncoder[A])
+ val dataset = sqlContext.createDataset(data)(TypedExpressionEncoder[A])
- TypedDataset.create[A](dataset)
- }
+ TypedDataset.create[A](dataset)
+ }
def create[A: TypedEncoder](dataset: Dataset[A]): TypedDataset[A] =
createUnsafe(dataset.toDF())
/**
- * Creates a [[frameless.TypedDataset]] from a Spark [[org.apache.spark.sql.DataFrame]].
- * Note that the names and types need to align!
- *
- * This is an unsafe operation: If the schemas do not align,
- * the error will be captured at runtime (not during compilation).
- */
+ * Creates a [[frameless.TypedDataset]] from a Spark [[org.apache.spark.sql.DataFrame]].
+ * Note that the names and types need to align!
+ *
+ * This is an unsafe operation: If the schemas do not align,
+ * the error will be captured at runtime (not during compilation).
+ */
def createUnsafe[A: TypedEncoder](df: DataFrame): TypedDataset[A] = {
val e = TypedEncoder[A]
val output: Seq[Attribute] = df.queryExecution.analyzed.output
@@ -1358,7 +1711,8 @@ object TypedDataset {
throw new IllegalStateException(
s"Unsupported creation of TypedDataset with ${targetFields.size} column(s) " +
s"from a DataFrame with ${output.size} columns. " +
- "Try to `select()` the proper columns in the right order before calling `create()`.")
+ "Try to `select()` the proper columns in the right order before calling `create()`."
+ )
}
// Adapt names if they are not the same (note: types still might not match)
@@ -1368,7 +1722,7 @@ object TypedDataset {
val canSelect = targetColNames.toSet.subsetOf(output.map(_.name).toSet)
val reshaped = if (shouldReshape && canSelect) {
- df.select(targetColNames.head, targetColNames.tail:_*)
+ df.select(targetColNames.head, targetColNames.tail: _*)
} else if (shouldReshape) {
df.toDF(targetColNames: _*)
} else {
@@ -1378,9 +1732,14 @@ object TypedDataset {
new TypedDataset[A](reshaped.as[A](TypedExpressionEncoder[A]))
}
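// Illustrative sketch, not part of the diff: `createUnsafe` defers schema checking to
// runtime, so a mismatched DataFrame only fails once the encoder is applied. Assumes an
// implicit SparkSession named `spark`; `Pair` is a hypothetical case class.
import spark.implicits._
case class Pair(a: Int, b: String)
val df = Seq((1, "x"), (2, "y")).toDF("a", "b")
val typed: TypedDataset[Pair] = TypedDataset.createUnsafe[Pair](df)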
- /** Prefer `TypedDataset.create` over `TypedDataset.unsafeCreate` unless you
- * know what you are doing. */
- @deprecated("Prefer TypedDataset.create over TypedDataset.unsafeCreate", "0.3.0")
+ /**
+ * Prefer `TypedDataset.create` over `TypedDataset.unsafeCreate` unless you
+ * know what you are doing.
+ */
+ @deprecated(
+ "Prefer TypedDataset.create over TypedDataset.unsafeCreate",
+ "0.3.0"
+ )
def unsafeCreate[A: TypedEncoder](dataset: Dataset[A]): TypedDataset[A] = {
new TypedDataset[A](dataset)
}
diff --git a/dataset/src/main/scala/frameless/TypedDatasetForwarded.scala b/dataset/src/main/scala/frameless/TypedDatasetForwarded.scala
index d417caf8e..b4beac7bf 100644
--- a/dataset/src/main/scala/frameless/TypedDatasetForwarded.scala
+++ b/dataset/src/main/scala/frameless/TypedDatasetForwarded.scala
@@ -6,366 +6,428 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.{DataFrame, DataFrameWriter, SQLContext, SparkSession}
+import org.apache.spark.sql.{
+ DataFrame,
+ DataFrameWriter,
+ SQLContext,
+ SparkSession
+}
import org.apache.spark.storage.StorageLevel
import scala.util.Random
-/** This trait implements [[TypedDataset]] methods that have the same signature
- * than their `Dataset` equivalent. Each method simply forwards the call to the
- * underlying `Dataset`.
- *
- * Documentation marked "apache/spark" is thanks to apache/spark Contributors
- * at https://github.com/apache/spark, licensed under Apache v2.0 available at
- * http://www.apache.org/licenses/LICENSE-2.0
- */
+/**
+ * This trait implements [[TypedDataset]] methods that have the same signature
+ * as their `Dataset` equivalents. Each method simply forwards the call to the
+ * underlying `Dataset`.
+ *
+ * Documentation marked "apache/spark" is thanks to apache/spark Contributors
+ * at https://github.com/apache/spark, licensed under Apache v2.0 available at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
trait TypedDatasetForwarded[T] { self: TypedDataset[T] =>
override def toString: String =
dataset.toString
/**
- * Returns a `SparkSession` from this [[TypedDataset]].
- */
+ * Returns a `SparkSession` from this [[TypedDataset]].
+ */
def sparkSession: SparkSession =
dataset.sparkSession
/**
- * Returns a `SQLContext` from this [[TypedDataset]].
- */
+ * Returns a `SQLContext` from this [[TypedDataset]].
+ */
def sqlContext: SQLContext =
dataset.sqlContext
/**
- * Returns the schema of this Dataset.
- *
- * apache/spark
- */
+ * Returns the schema of this Dataset.
+ *
+ * apache/spark
+ */
def schema: StructType =
dataset.schema
- /** Prints the schema of the underlying `Dataset` to the console in a nice tree format.
- *
- * apache/spark
+ /**
+ * Prints the schema of the underlying `Dataset` to the console in a nice tree format.
+ *
+ * apache/spark
*/
def printSchema(): Unit =
dataset.printSchema()
- /** Prints the plans (logical and physical) to the console for debugging purposes.
- *
- * apache/spark
+ /**
+ * Prints the plans (logical and physical) to the console for debugging purposes.
+ *
+ * apache/spark
*/
def explain(extended: Boolean = false): Unit =
dataset.explain(extended)
/**
- * Returns a `QueryExecution` from this [[TypedDataset]].
- *
- * It is the primary workflow for executing relational queries using Spark. Designed to allow easy
- * access to the intermediate phases of query execution for developers.
- *
- * apache/spark
- */
+ * Returns a `QueryExecution` from this [[TypedDataset]].
+ *
+ * It is the primary workflow for executing relational queries using Spark. Designed to allow easy
+ * access to the intermediate phases of query execution for developers.
+ *
+ * apache/spark
+ */
def queryExecution: QueryExecution =
dataset.queryExecution
- /** Converts this strongly typed collection of data to generic Dataframe. In contrast to the
- * strongly typed objects that Dataset operations work on, a Dataframe returns generic Row
- * objects that allow fields to be accessed by ordinal or name.
- *
- * apache/spark
- */
+ /**
+ * Converts this strongly typed collection of data to a generic DataFrame. In contrast to the
+ * strongly typed objects that Dataset operations work on, a DataFrame returns generic Row
+ * objects that allow fields to be accessed by ordinal or name.
+ *
+ * apache/spark
+ */
def toDF(): DataFrame =
dataset.toDF()
- /** Converts this [[TypedDataset]] to an RDD.
- *
- * apache/spark
- */
+ /**
+ * Converts this [[TypedDataset]] to an RDD.
+ *
+ * apache/spark
+ */
def rdd: RDD[T] =
dataset.rdd
- /** Returns a new [[TypedDataset]] that has exactly `numPartitions` partitions.
- *
- * apache/spark
- */
+ /**
+ * Returns a new [[TypedDataset]] that has exactly `numPartitions` partitions.
+ *
+ * apache/spark
+ */
def repartition(numPartitions: Int): TypedDataset[T] =
TypedDataset.create(dataset.repartition(numPartitions))
-
/**
- * Get the [[TypedDataset]]'s current storage level, or StorageLevel.NONE if not persisted.
- *
- * apache/spark
- */
+ * Get the [[TypedDataset]]'s current storage level, or StorageLevel.NONE if not persisted.
+ *
+ * apache/spark
+ */
def storageLevel(): StorageLevel =
dataset.storageLevel
/**
- * Returns the content of the [[TypedDataset]] as a Dataset of JSON strings.
- *
- * apache/spark
- */
+ * Returns the content of the [[TypedDataset]] as a Dataset of JSON strings.
+ *
+ * apache/spark
+ */
def toJSON: TypedDataset[String] =
TypedDataset.create(dataset.toJSON)
/**
- * Interface for saving the content of the non-streaming [[TypedDataset]] out into external storage.
- *
- * apache/spark
- */
+ * Interface for saving the content of the non-streaming [[TypedDataset]] out into external storage.
+ *
+ * apache/spark
+ */
def write: DataFrameWriter[T] =
dataset.write
/**
- * Interface for saving the content of the streaming Dataset out into external storage.
- *
- * apache/spark
- */
+ * Interface for saving the content of the streaming Dataset out into external storage.
+ *
+ * apache/spark
+ */
def writeStream: DataStreamWriter[T] =
dataset.writeStream
-
- /** Returns a new [[TypedDataset]] that has exactly `numPartitions` partitions.
- * Similar to coalesce defined on an RDD, this operation results in a narrow dependency, e.g.
- * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of
- * the 100 new partitions will claim 10 of the current partitions.
- *
- * apache/spark
- */
+
+ /**
+ * Returns a new [[TypedDataset]] that has exactly `numPartitions` partitions.
+ * Similar to coalesce defined on an RDD, this operation results in a narrow dependency, e.g.
+ * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of
+ * the 100 new partitions will claim 10 of the current partitions.
+ *
+ * apache/spark
+ */
def coalesce(numPartitions: Int): TypedDataset[T] =
TypedDataset.create(dataset.coalesce(numPartitions))
/**
- * Returns an `Array` that contains all column names in this [[TypedDataset]].
- */
+ * Returns an `Array` that contains all column names in this [[TypedDataset]].
+ */
def columns: Array[String] =
dataset.columns
- /** Concise syntax for chaining custom transformations.
- *
- * apache/spark
- */
+ /**
+ * Concise syntax for chaining custom transformations.
+ *
+ * apache/spark
+ */
def transform[U](t: TypedDataset[T] => TypedDataset[U]): TypedDataset[U] =
t(this)
- /** Returns a new Dataset by taking the first `n` rows. The difference between this function
- * and `head` is that `head` is an action and returns an array (by triggering query execution)
- * while `limit` returns a new Dataset.
- *
- * apache/spark
- */
+ /**
+ * Returns a new Dataset by taking the first `n` rows. The difference between this function
+ * and `head` is that `head` is an action and returns an array (by triggering query execution)
+ * while `limit` returns a new Dataset.
+ *
+ * apache/spark
+ */
def limit(n: Int): TypedDataset[T] =
TypedDataset.create(dataset.limit(n))
- /** Returns a new [[TypedDataset]] by sampling a fraction of records.
- *
- * apache/spark
- */
- def sample(withReplacement: Boolean, fraction: Double, seed: Long = Random.nextLong()): TypedDataset[T] =
+ /**
+ * Returns a new [[TypedDataset]] by sampling a fraction of records.
+ *
+ * apache/spark
+ */
+ def sample(
+ withReplacement: Boolean,
+ fraction: Double,
+ seed: Long = Random.nextLong()
+ ): TypedDataset[T] =
TypedDataset.create(dataset.sample(withReplacement, fraction, seed))
- /** Returns a new [[TypedDataset]] that contains only the unique elements of this [[TypedDataset]].
- *
- * Note that, equality checking is performed directly on the encoded representation of the data
- * and thus is not affected by a custom `equals` function defined on `T`.
- *
- * apache/spark
- */
+ /**
+ * Returns a new [[TypedDataset]] that contains only the unique elements of this [[TypedDataset]].
+ *
+ * Note that equality checking is performed directly on the encoded representation of the data
+ * and thus is not affected by a custom `equals` function defined on `T`.
+ *
+ * apache/spark
+ */
def distinct: TypedDataset[T] =
TypedDataset.create(dataset.distinct())
/**
- * Returns a best-effort snapshot of the files that compose this [[TypedDataset]]. This method simply
- * asks each constituent BaseRelation for its respective files and takes the union of all results.
- * Depending on the source relations, this may not find all input files. Duplicates are removed.
- *
- * apache/spark
- */
+ * Returns a best-effort snapshot of the files that compose this [[TypedDataset]]. This method simply
+ * asks each constituent BaseRelation for its respective files and takes the union of all results.
+ * Depending on the source relations, this may not find all input files. Duplicates are removed.
+ *
+ * apache/spark
+ */
def inputFiles: Array[String] =
dataset.inputFiles
/**
- * Returns true if the `collect` and `take` methods can be run locally
- * (without any Spark executors).
- *
- * apache/spark
- */
+ * Returns true if the `collect` and `take` methods can be run locally
+ * (without any Spark executors).
+ *
+ * apache/spark
+ */
def isLocal: Boolean =
dataset.isLocal
/**
- * Returns true if this [[TypedDataset]] contains one or more sources that continuously
- * return data as it arrives. A [[TypedDataset]] that reads data from a streaming source
- * must be executed as a `StreamingQuery` using the `start()` method in
- * `DataStreamWriter`. Methods that return a single answer, e.g. `count()` or
- * `collect()`, will throw an `AnalysisException` when there is a streaming
- * source present.
- *
- * apache/spark
- */
+ * Returns true if this [[TypedDataset]] contains one or more sources that continuously
+ * return data as it arrives. A [[TypedDataset]] that reads data from a streaming source
+ * must be executed as a `StreamingQuery` using the `start()` method in
+ * `DataStreamWriter`. Methods that return a single answer, e.g. `count()` or
+ * `collect()`, will throw an `AnalysisException` when there is a streaming
+ * source present.
+ *
+ * apache/spark
+ */
def isStreaming: Boolean =
dataset.isStreaming
- /** Returns a new [[TypedDataset]] that contains only the elements of this [[TypedDataset]] that are also
- * present in `other`.
- *
- * Note that, equality checking is performed directly on the encoded representation of the data
- * and thus is not affected by a custom `equals` function defined on `T`.
- *
- * apache/spark
- */
+ /**
+ * Returns a new [[TypedDataset]] that contains only the elements of this [[TypedDataset]] that are also
+ * present in `other`.
+ *
+ * Note that equality checking is performed directly on the encoded representation of the data
+ * and thus is not affected by a custom `equals` function defined on `T`.
+ *
+ * apache/spark
+ */
def intersect(other: TypedDataset[T]): TypedDataset[T] =
TypedDataset.create(dataset.intersect(other.dataset))
/**
- * Randomly splits this [[TypedDataset]] with the provided weights.
- * Weights for splits, will be normalized if they don't sum to 1.
- *
- * apache/spark
- */
+ * Randomly splits this [[TypedDataset]] with the provided weights.
+ * Weights for splits will be normalized if they don't sum to 1.
+ *
+ * apache/spark
+ */
// $COVERAGE-OFF$ We can not test this method because it is non-deterministic.
def randomSplit(weights: Array[Double]): Array[TypedDataset[T]] =
dataset.randomSplit(weights).map(TypedDataset.create[T])
// $COVERAGE-ON$
/**
- * Randomly splits this [[TypedDataset]] with the provided weights.
- * Weights for splits, will be normalized if they don't sum to 1.
- *
- * apache/spark
- */
+ * Randomly splits this [[TypedDataset]] with the provided weights.
+ * Weights for splits will be normalized if they don't sum to 1.
+ *
+ * apache/spark
+ */
def randomSplit(weights: Array[Double], seed: Long): Array[TypedDataset[T]] =
dataset.randomSplit(weights, seed).map(TypedDataset.create[T])
/**
- * Returns a Java list that contains randomly split [[TypedDataset]] with the provided weights.
- * Weights for splits, will be normalized if they don't sum to 1.
- *
- * apache/spark
- */
- def randomSplitAsList(weights: Array[Double], seed: Long): util.List[TypedDataset[T]] = {
+ * Returns a Java list that contains randomly split [[TypedDataset]] with the provided weights.
+ * Weights for splits will be normalized if they don't sum to 1.
+ *
+ * apache/spark
+ */
+ def randomSplitAsList(
+ weights: Array[Double],
+ seed: Long
+ ): util.List[TypedDataset[T]] = {
val values = randomSplit(weights, seed)
java.util.Arrays.asList(values: _*)
}
-
- /** Returns a new Dataset containing rows in this Dataset but not in another Dataset.
- * This is equivalent to `EXCEPT` in SQL.
- *
- * Note that, equality checking is performed directly on the encoded representation of the data
- * and thus is not affected by a custom `equals` function defined on `T`.
- *
- * apache/spark
- */
+ /**
+ * Returns a new Dataset containing rows in this Dataset but not in another Dataset.
+ * This is equivalent to `EXCEPT` in SQL.
+ *
+ * Note that equality checking is performed directly on the encoded representation of the data
+ * and thus is not affected by a custom `equals` function defined on `T`.
+ *
+ * apache/spark
+ */
def except(other: TypedDataset[T]): TypedDataset[T] =
TypedDataset.create(dataset.except(other.dataset))
- /** Persist this [[TypedDataset]] with the default storage level (`MEMORY_AND_DISK`).
- *
- * apache/spark
- */
+ /**
+ * Persist this [[TypedDataset]] with the default storage level (`MEMORY_AND_DISK`).
+ *
+ * apache/spark
+ */
def cache(): TypedDataset[T] =
TypedDataset.create(dataset.cache())
- /** Persist this [[TypedDataset]] with the given storage level.
- * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`,
- * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, `MEMORY_AND_DISK_2`, etc.
- *
- * apache/spark
- */
- def persist(newLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK): TypedDataset[T] =
+ /**
+ * Persist this [[TypedDataset]] with the given storage level.
+ * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`,
+ * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, `MEMORY_AND_DISK_2`, etc.
+ *
+ * apache/spark
+ */
+ def persist(
+ newLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK
+ ): TypedDataset[T] =
TypedDataset.create(dataset.persist(newLevel))
- /** Mark the [[TypedDataset]] as non-persistent, and remove all blocks for it from memory and disk.
- * @param blocking Whether to block until all blocks are deleted.
- *
- * apache/spark
- */
+ /**
+ * Mark the [[TypedDataset]] as non-persistent, and remove all blocks for it from memory and disk.
+ * @param blocking Whether to block until all blocks are deleted.
+ *
+ * apache/spark
+ */
def unpersist(blocking: Boolean = false): TypedDataset[T] =
TypedDataset.create(dataset.unpersist(blocking))
// $COVERAGE-OFF$ We do not test deprecated method since forwarded methods are tested.
- @deprecated("deserialized methods have moved to a separate section to highlight their runtime overhead", "0.4.0")
+ @deprecated(
+ "deserialized methods have moved to a separate section to highlight their runtime overhead",
+ "0.4.0"
+ )
def map[U: TypedEncoder](func: T => U): TypedDataset[U] =
deserialized.map(func)
- @deprecated("deserialized methods have moved to a separate section to highlight their runtime overhead", "0.4.0")
- def mapPartitions[U: TypedEncoder](func: Iterator[T] => Iterator[U]): TypedDataset[U] =
+ @deprecated(
+ "deserialized methods have moved to a separate section to highlight their runtime overhead",
+ "0.4.0"
+ )
+ def mapPartitions[U: TypedEncoder](
+ func: Iterator[T] => Iterator[U]
+ ): TypedDataset[U] =
deserialized.mapPartitions(func)
- @deprecated("deserialized methods have moved to a separate section to highlight their runtime overhead", "0.4.0")
+ @deprecated(
+ "deserialized methods have moved to a separate section to highlight their runtime overhead",
+ "0.4.0"
+ )
def flatMap[U: TypedEncoder](func: T => TraversableOnce[U]): TypedDataset[U] =
deserialized.flatMap(func)
- @deprecated("deserialized methods have moved to a separate section to highlight their runtime overhead", "0.4.0")
+ @deprecated(
+ "deserialized methods have moved to a separate section to highlight their runtime overhead",
+ "0.4.0"
+ )
def filter(func: T => Boolean): TypedDataset[T] =
deserialized.filter(func)
- @deprecated("deserialized methods have moved to a separate section to highlight their runtime overhead", "0.4.0")
+ @deprecated(
+ "deserialized methods have moved to a separate section to highlight their runtime overhead",
+ "0.4.0"
+ )
def reduceOption[F[_]: SparkDelay](func: (T, T) => T): F[Option[T]] =
deserialized.reduceOption(func)
// $COVERAGE-ON$
- /** Methods on `TypedDataset[T]` that go through a full serialization and
- * deserialization of `T`, and execute outside of the Catalyst runtime.
- *
- * @example The correct way to do a projection on a single column is to
- * use the `select` method as follows:
- *
- * {{{
- * ds: TypedDataset[(String, String, String)] -> ds.select(ds('_2)).run()
- * }}}
- *
- * Spark provides an alternative way to obtain the same resulting `Dataset`,
- * using the `map` method:
- *
- * {{{
- * ds: TypedDataset[(String, String, String)] -> ds.deserialized.map(_._2).run()
- * }}}
- *
- * This second approach is however substantially slower than the first one,
- * and should be avoided as possible. Indeed, under the hood this `map` will
- * deserialize the entire `Tuple3` to an full JVM object, call the apply
- * method of the `_._2` closure on it, and serialize the resulting String back
- * to its Catalyst representation.
- */
+ /**
+ * Methods on `TypedDataset[T]` that go through a full serialization and
+ * deserialization of `T`, and execute outside of the Catalyst runtime.
+ *
+ * @example The correct way to do a projection on a single column is to
+ * use the `select` method as follows:
+ *
+ * {{{
+ * ds: TypedDataset[(String, String, String)] -> ds.select(ds('_2)).run()
+ * }}}
+ *
+ * Spark provides an alternative way to obtain the same resulting `Dataset`,
+ * using the `map` method:
+ *
+ * {{{
+ * ds: TypedDataset[(String, String, String)] -> ds.deserialized.map(_._2).run()
+ * }}}
+ *
+ * This second approach, however, is substantially slower than the first one and
+ * should be avoided when possible. Indeed, under the hood this `map` will
+ * deserialize the entire `Tuple3` to a full JVM object, call the apply
+ * method of the `_._2` closure on it, and serialize the resulting String back
+ * to its Catalyst representation.
+ */
object deserialized {
- /** Returns a new [[TypedDataset]] that contains the result of applying `func` to each element.
- *
- * apache/spark
- */
+
+ /**
+ * Returns a new [[TypedDataset]] that contains the result of applying `func` to each element.
+ *
+ * apache/spark
+ */
def map[U: TypedEncoder](func: T => U): TypedDataset[U] =
TypedDataset.create(self.dataset.map(func)(TypedExpressionEncoder[U]))
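// Illustrative sketch, not part of the diff: the scaladoc above contrasts a Catalyst-level
// projection with a deserializing map. Both produce the same values, but the second one
// round-trips every row through a JVM object. `triples` is a hypothetical dataset.
val triples: TypedDataset[(String, String, String)] =
  TypedDataset.create(Seq(("a", "b", "c")))
val viaSelect = triples.select(triples('_2))   // stays in the Catalyst representation
val viaMap    = triples.deserialized.map(_._2) // full (de)serialization per row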
- /** Returns a new [[TypedDataset]] that contains the result of applying `func` to each partition.
- *
- * apache/spark
- */
- def mapPartitions[U: TypedEncoder](func: Iterator[T] => Iterator[U]): TypedDataset[U] =
- TypedDataset.create(self.dataset.mapPartitions(func)(TypedExpressionEncoder[U]))
-
- /** Returns a new [[TypedDataset]] by first applying a function to all elements of this [[TypedDataset]],
- * and then flattening the results.
- *
- * apache/spark
- */
- def flatMap[U: TypedEncoder](func: T => TraversableOnce[U]): TypedDataset[U] =
+ /**
+ * Returns a new [[TypedDataset]] that contains the result of applying `func` to each partition.
+ *
+ * apache/spark
+ */
+ def mapPartitions[U: TypedEncoder](
+ func: Iterator[T] => Iterator[U]
+ ): TypedDataset[U] =
+ TypedDataset.create(
+ self.dataset.mapPartitions(func)(TypedExpressionEncoder[U])
+ )
+
+ /**
+ * Returns a new [[TypedDataset]] by first applying a function to all elements of this [[TypedDataset]],
+ * and then flattening the results.
+ *
+ * apache/spark
+ */
+ def flatMap[U: TypedEncoder](
+ func: T => TraversableOnce[U]
+ ): TypedDataset[U] =
TypedDataset.create(self.dataset.flatMap(func)(TypedExpressionEncoder[U]))
- /** Returns a new [[TypedDataset]] that only contains elements where `func` returns `true`.
- *
- * apache/spark
- */
+ /**
+ * Returns a new [[TypedDataset]] that only contains elements where `func` returns `true`.
+ *
+ * apache/spark
+ */
def filter(func: T => Boolean): TypedDataset[T] =
TypedDataset.create(self.dataset.filter(func))
- /** Optionally reduces the elements of this [[TypedDataset]] using the specified binary function. The given
- * `func` must be commutative and associative or the result may be non-deterministic.
- *
- * Differs from `Dataset#reduce` by wrapping its result into an `Option` and an effect-suspending `F`.
- */
- def reduceOption[F[_]](func: (T, T) => T)(implicit F: SparkDelay[F]): F[Option[T]] =
+ /**
+ * Optionally reduces the elements of this [[TypedDataset]] using the specified binary function. The given
+ * `func` must be commutative and associative or the result may be non-deterministic.
+ *
+ * Differs from `Dataset#reduce` by wrapping its result into an `Option` and an effect-suspending `F`.
+ */
+ def reduceOption[F[_]](
+ func: (T, T) => T
+ )(implicit
+ F: SparkDelay[F]
+ ): F[Option[T]] =
F.delay {
try {
Option(self.dataset.reduce(func))
diff --git a/dataset/src/main/scala/frameless/TypedExpressionEncoder.scala b/dataset/src/main/scala/frameless/TypedExpressionEncoder.scala
index 5b78cd292..121fd7fcc 100644
--- a/dataset/src/main/scala/frameless/TypedExpressionEncoder.scala
+++ b/dataset/src/main/scala/frameless/TypedExpressionEncoder.scala
@@ -3,18 +3,23 @@ package frameless
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, CreateNamedStruct, If}
+import org.apache.spark.sql.catalyst.expressions.{
+ BoundReference,
+ CreateNamedStruct,
+ If
+}
import org.apache.spark.sql.types.StructType
object TypedExpressionEncoder {
- /** In Spark, DataFrame has always schema of StructType
- *
- * DataFrames of primitive types become records
- * with a single field called "value" set in ExpressionEncoder.
- */
+ /**
+ * In Spark, a DataFrame always has a schema of StructType.
+ *
+ * DataFrames of primitive types become records
+ * with a single field called "value" set in ExpressionEncoder.
+ */
def targetStructType[A](encoder: TypedEncoder[A]): StructType =
- encoder.catalystRepr match {
+ encoder.catalystRepr match {
case x: StructType =>
if (encoder.nullable) StructType(x.fields.map(_.copy(nullable = true)))
else x
@@ -22,7 +27,10 @@ object TypedExpressionEncoder {
case dt => new StructType().add("value", dt, nullable = encoder.nullable)
}
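// Illustrative sketch, not part of the diff: for a primitive encoder the schema is wrapped
// in a single field called "value", as the scaladoc above describes. For example,
// targetStructType(TypedEncoder[Long]) is equivalent to
//   new StructType().add("value", LongType, nullable = false)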
- def apply[T](implicit encoder: TypedEncoder[T]): Encoder[T] = {
+ def apply[T](
+ implicit
+ encoder: TypedEncoder[T]
+ ): Encoder[T] = {
val in = BoundReference(0, encoder.jvmRepr, encoder.nullable)
val (out, serializer) = encoder.toCatalyst(in) match {
@@ -46,4 +54,3 @@ object TypedExpressionEncoder {
)
}
}
-
diff --git a/dataset/src/main/scala/frameless/With.scala b/dataset/src/main/scala/frameless/With.scala
index 11ceaa35b..85ce1d145 100644
--- a/dataset/src/main/scala/frameless/With.scala
+++ b/dataset/src/main/scala/frameless/With.scala
@@ -1,14 +1,15 @@
package frameless
-/** Compute the intersection of two types:
- *
- * - With[A, A] = A
- * - With[A, B] = A with B (when A != B)
- *
- * This type function is needed to prevent IDEs from infering large types
- * with shape `A with A with ... with A`. These types could be confusing for
- * both end users and IDE's type checkers.
- */
+/**
+ * Compute the intersection of two types:
+ *
+ * - With[A, A] = A
+ * - With[A, B] = A with B (when A != B)
+ *
+ * This type function is needed to prevent IDEs from inferring large types
+ * with shape `A with A with ... with A`. These types could be confusing for
+ * both end users and IDEs' type checkers.
+ */
trait With[A, B] { type Out }
object With extends LowPrioWith {
diff --git a/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala b/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala
index e371ea048..61a8fdbab 100644
--- a/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala
+++ b/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala
@@ -3,71 +3,90 @@ package functions
import org.apache.spark.sql.FramelessInternals.expr
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.{functions => sparkFunctions}
+import org.apache.spark.sql.{ functions => sparkFunctions }
import frameless.syntax._
import scala.annotation.nowarn
trait AggregateFunctions {
- /** Aggregate function: returns the number of items in a group.
- *
- * apache/spark
- */
+
+ /**
+ * Aggregate function: returns the number of items in a group.
+ *
+ * apache/spark
+ */
def count[T](): TypedAggregate[T, Long] =
sparkFunctions.count(sparkFunctions.lit(1)).typedAggregate
- /** Aggregate function: returns the number of items in a group for which the selected column is not null.
- *
- * apache/spark
- */
+ /**
+ * Aggregate function: returns the number of items in a group for which the selected column is not null.
+ *
+ * apache/spark
+ */
def count[T](column: TypedColumn[T, _]): TypedAggregate[T, Long] =
sparkFunctions.count(column.untyped).typedAggregate
- /** Aggregate function: returns the number of distinct items in a group.
- *
- * apache/spark
- */
+ /**
+ * Aggregate function: returns the number of distinct items in a group.
+ *
+ * apache/spark
+ */
def countDistinct[T](column: TypedColumn[T, _]): TypedAggregate[T, Long] =
sparkFunctions.countDistinct(column.untyped).typedAggregate
- /** Aggregate function: returns the approximate number of distinct items in a group.
- */
- def approxCountDistinct[T](column: TypedColumn[T, _]): TypedAggregate[T, Long] =
+ /**
+ * Aggregate function: returns the approximate number of distinct items in a group.
+ */
+ def approxCountDistinct[T](
+ column: TypedColumn[T, _]
+ ): TypedAggregate[T, Long] =
sparkFunctions.approx_count_distinct(column.untyped).typedAggregate
- /** Aggregate function: returns the approximate number of distinct items in a group.
- *
- * @param rsd maximum estimation error allowed (default = 0.05)
- *
- * apache/spark
- */
- def approxCountDistinct[T](column: TypedColumn[T, _], rsd: Double): TypedAggregate[T, Long] =
+ /**
+ * Aggregate function: returns the approximate number of distinct items in a group.
+ *
+ * @param rsd maximum estimation error allowed (default = 0.05)
+ *
+ * apache/spark
+ */
+ def approxCountDistinct[T](
+ column: TypedColumn[T, _],
+ rsd: Double
+ ): TypedAggregate[T, Long] =
sparkFunctions.approx_count_distinct(column.untyped, rsd).typedAggregate
- /** Aggregate function: returns a list of objects with duplicates.
- *
- * apache/spark
- */
- def collectList[T, A: TypedEncoder](column: TypedColumn[T, A]): TypedAggregate[T, Vector[A]] =
+ /**
+ * Aggregate function: returns a list of objects with duplicates.
+ *
+ * apache/spark
+ */
+ def collectList[T, A: TypedEncoder](
+ column: TypedColumn[T, A]
+ ): TypedAggregate[T, Vector[A]] =
sparkFunctions.collect_list(column.untyped).typedAggregate
- /** Aggregate function: returns a set of objects with duplicate elements eliminated.
- *
- * apache/spark
- */
- def collectSet[T, A: TypedEncoder](column: TypedColumn[T, A]): TypedAggregate[T, Vector[A]] =
+ /**
+ * Aggregate function: returns a set of objects with duplicate elements eliminated.
+ *
+ * apache/spark
+ */
+ def collectSet[T, A: TypedEncoder](
+ column: TypedColumn[T, A]
+ ): TypedAggregate[T, Vector[A]] =
sparkFunctions.collect_set(column.untyped).typedAggregate
- /** Aggregate function: returns the sum of all values in the given column.
- *
- * apache/spark
- */
- def sum[A, T, Out](column: TypedColumn[T, A])(
- implicit
- summable: CatalystSummable[A, Out],
- oencoder: TypedEncoder[Out],
- aencoder: TypedEncoder[A]
- ): TypedAggregate[T, Out] = {
+ /**
+ * Aggregate function: returns the sum of all values in the given column.
+ *
+ * apache/spark
+ */
+ def sum[A, T, Out](
+ column: TypedColumn[T, A]
+ )(implicit
+ summable: CatalystSummable[A, Out],
+ oencoder: TypedEncoder[Out],
+ aencoder: TypedEncoder[A]
+ ): TypedAggregate[T, Out] = {
val zeroExpr = Literal.create(summable.zero, TypedEncoder[A].catalystRepr)
val sumExpr = expr(sparkFunctions.sum(column.untyped))
val sumOrZero = Coalesce(Seq(sumExpr, zeroExpr))
@@ -75,17 +94,19 @@ trait AggregateFunctions {
new TypedAggregate[T, Out](sumOrZero)
}
- /** Aggregate function: returns the sum of distinct values in the column.
- *
- * apache/spark
- */
+ /**
+ * Aggregate function: returns the sum of distinct values in the column.
+ *
+ * apache/spark
+ */
  @nowarn // suppress sparkFunctions.sumDistinct call which is used to maintain Spark 3.1.x backwards compat
- def sumDistinct[A, T, Out](column: TypedColumn[T, A])(
- implicit
- summable: CatalystSummable[A, Out],
- oencoder: TypedEncoder[Out],
- aencoder: TypedEncoder[A]
- ): TypedAggregate[T, Out] = {
+ def sumDistinct[A, T, Out](
+ column: TypedColumn[T, A]
+ )(implicit
+ summable: CatalystSummable[A, Out],
+ oencoder: TypedEncoder[Out],
+ aencoder: TypedEncoder[A]
+ ): TypedAggregate[T, Out] = {
val zeroExpr = Literal.create(summable.zero, TypedEncoder[A].catalystRepr)
val sumExpr = expr(sparkFunctions.sumDistinct(column.untyped))
val sumOrZero = Coalesce(Seq(sumExpr, zeroExpr))
@@ -93,186 +114,225 @@ trait AggregateFunctions {
new TypedAggregate[T, Out](sumOrZero)
}
- /** Aggregate function: returns the average of the values in a group.
- *
- * apache/spark
- */
- def avg[A, T, Out](column: TypedColumn[T, A])(
- implicit
- averageable: CatalystAverageable[A, Out],
- oencoder: TypedEncoder[Out]
- ): TypedAggregate[T, Out] = {
+ /**
+ * Aggregate function: returns the average of the values in a group.
+ *
+ * apache/spark
+ */
+ def avg[A, T, Out](
+ column: TypedColumn[T, A]
+ )(implicit
+ averageable: CatalystAverageable[A, Out],
+ oencoder: TypedEncoder[Out]
+ ): TypedAggregate[T, Out] = {
new TypedAggregate[T, Out](sparkFunctions.avg(column.untyped))
}
- /** Aggregate function: returns the unbiased variance of the values in a group.
- *
- * @note In Spark variance always returns Double
- * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#186]]
- *
- * apache/spark
- */
- def variance[A: CatalystVariance, T](column: TypedColumn[T, A]): TypedAggregate[T, Double] =
+ /**
+ * Aggregate function: returns the unbiased variance of the values in a group.
+ *
+ * @note In Spark variance always returns Double
+ * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#186]]
+ *
+ * apache/spark
+ */
+ def variance[A: CatalystVariance, T](
+ column: TypedColumn[T, A]
+ ): TypedAggregate[T, Double] =
sparkFunctions.variance(column.untyped).typedAggregate
- /** Aggregate function: returns the sample standard deviation.
- *
- * @note In Spark stddev always returns Double
- * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#155]]
- *
- * apache/spark
- */
- def stddev[A: CatalystVariance, T](column: TypedColumn[T, A]): TypedAggregate[T, Double] =
+ /**
+ * Aggregate function: returns the sample standard deviation.
+ *
+ * @note In Spark stddev always returns Double
+ * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#155]]
+ *
+ * apache/spark
+ */
+ def stddev[A: CatalystVariance, T](
+ column: TypedColumn[T, A]
+ ): TypedAggregate[T, Double] =
sparkFunctions.stddev(column.untyped).typedAggregate
/**
- * Aggregate function: returns the standard deviation of a column by population.
- *
- * @note In Spark stddev always returns Double
- * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L143]]
- *
- * apache/spark
- */
- def stddevPop[A, T](column: TypedColumn[T, A])(implicit ev: CatalystCast[A, Double]): TypedAggregate[T, Option[Double]] = {
+ * Aggregate function: returns the standard deviation of a column by population.
+ *
+ * @note In Spark stddev always returns Double
+ * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L143]]
+ *
+ * apache/spark
+ */
+ def stddevPop[A, T](
+ column: TypedColumn[T, A]
+ )(implicit
+ ev: CatalystCast[A, Double]
+ ): TypedAggregate[T, Option[Double]] = {
new TypedAggregate[T, Option[Double]](
sparkFunctions.stddev_pop(column.cast[Double].untyped)
)
}
/**
- * Aggregate function: returns the standard deviation of a column by sample.
- *
- * @note In Spark stddev always returns Double
- * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L160]]
- *
- * apache/spark
- */
- def stddevSamp[A, T](column: TypedColumn[T, A])(implicit ev: CatalystCast[A, Double] ): TypedAggregate[T, Option[Double]] = {
+ * Aggregate function: returns the standard deviation of a column by sample.
+ *
+ * @note In Spark stddev always returns Double
+ * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L160]]
+ *
+ * apache/spark
+ */
+ def stddevSamp[A, T](
+ column: TypedColumn[T, A]
+ )(implicit
+ ev: CatalystCast[A, Double]
+ ): TypedAggregate[T, Option[Double]] = {
new TypedAggregate[T, Option[Double]](
sparkFunctions.stddev_samp(column.cast[Double].untyped)
)
}
- /** Aggregate function: returns the maximum value of the column in a group.
- *
- * apache/spark
- */
- def max[A: CatalystOrdered, T](column: TypedColumn[T, A]): TypedAggregate[T, A] = {
+ /**
+ * Aggregate function: returns the maximum value of the column in a group.
+ *
+ * apache/spark
+ */
+ def max[A: CatalystOrdered, T](
+ column: TypedColumn[T, A]
+ ): TypedAggregate[T, A] = {
implicit val c = column.uencoder
sparkFunctions.max(column.untyped).typedAggregate
}
- /** Aggregate function: returns the minimum value of the column in a group.
- *
- * apache/spark
- */
- def min[A: CatalystOrdered, T](column: TypedColumn[T, A]): TypedAggregate[T, A] = {
+ /**
+ * Aggregate function: returns the minimum value of the column in a group.
+ *
+ * apache/spark
+ */
+ def min[A: CatalystOrdered, T](
+ column: TypedColumn[T, A]
+ ): TypedAggregate[T, A] = {
implicit val c = column.uencoder
sparkFunctions.min(column.untyped).typedAggregate
}
- /** Aggregate function: returns the first value in a group.
- *
- * The function by default returns the first values it sees. It will return the first non-null
- * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
- *
- * apache/spark
- */
+ /**
+ * Aggregate function: returns the first value in a group.
+ *
+   * The function by default returns the first value it sees. It will return the first non-null
+ * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+ *
+ * apache/spark
+ */
def first[A, T](column: TypedColumn[T, A]): TypedAggregate[T, A] = {
sparkFunctions.first(column.untyped).typedAggregate(column.uencoder)
}
/**
- * Aggregate function: returns the last value in a group.
- *
- * The function by default returns the last values it sees. It will return the last non-null
- * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
- *
- * apache/spark
- */
+ * Aggregate function: returns the last value in a group.
+ *
+   * The function by default returns the last value it sees. It will return the last non-null
+ * value it sees when ignoreNulls is set to true. If all values are null, then null is returned.
+ *
+ * apache/spark
+ */
def last[A, T](column: TypedColumn[T, A]): TypedAggregate[T, A] = {
implicit val c = column.uencoder
sparkFunctions.last(column.untyped).typedAggregate
}
/**
- * Aggregate function: returns the Pearson Correlation Coefficient for two columns.
- *
- * @note In Spark corr always returns Double
- * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala#L95]]
- *
- * apache/spark
- */
- def corr[A, B, T](column1: TypedColumn[T, A], column2: TypedColumn[T, B])
- (implicit
+ * Aggregate function: returns the Pearson Correlation Coefficient for two columns.
+ *
+ * @note In Spark corr always returns Double
+ * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala#L95]]
+ *
+ * apache/spark
+ */
+ def corr[A, B, T](
+ column1: TypedColumn[T, A],
+ column2: TypedColumn[T, B]
+ )(implicit
i0: CatalystCast[A, Double],
i1: CatalystCast[B, Double]
): TypedAggregate[T, Option[Double]] = {
- new TypedAggregate[T, Option[Double]](
- sparkFunctions.corr(column1.cast[Double].untyped, column2.cast[Double].untyped)
- )
- }
+ new TypedAggregate[T, Option[Double]](
+ sparkFunctions
+ .corr(column1.cast[Double].untyped, column2.cast[Double].untyped)
+ )
+ }
/**
- * Aggregate function: returns the covariance of two collumns.
- *
- * @note In Spark covar_pop always returns Double
- * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala#L82]]
- *
- * apache/spark
- */
- def covarPop[A, B, T](column1: TypedColumn[T, A], column2: TypedColumn[T, B])
- (implicit
+   * Aggregate function: returns the covariance of two columns.
+ *
+ * @note In Spark covar_pop always returns Double
+ * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala#L82]]
+ *
+ * apache/spark
+ */
+ def covarPop[A, B, T](
+ column1: TypedColumn[T, A],
+ column2: TypedColumn[T, B]
+ )(implicit
i0: CatalystCast[A, Double],
i1: CatalystCast[B, Double]
): TypedAggregate[T, Option[Double]] = {
- new TypedAggregate[T, Option[Double]](
- sparkFunctions.covar_pop(column1.cast[Double].untyped, column2.cast[Double].untyped)
- )
- }
+ new TypedAggregate[T, Option[Double]](
+ sparkFunctions
+ .covar_pop(column1.cast[Double].untyped, column2.cast[Double].untyped)
+ )
+ }
/**
- * Aggregate function: returns the covariance of two columns.
- *
- * @note In Spark covar_samp always returns Double
- * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala#L93]]
- *
- * apache/spark
- */
- def covarSamp[A, B, T](column1: TypedColumn[T, A], column2: TypedColumn[T, B])
- (implicit
+ * Aggregate function: returns the covariance of two columns.
+ *
+ * @note In Spark covar_samp always returns Double
+ * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala#L93]]
+ *
+ * apache/spark
+ */
+ def covarSamp[A, B, T](
+ column1: TypedColumn[T, A],
+ column2: TypedColumn[T, B]
+ )(implicit
i0: CatalystCast[A, Double],
i1: CatalystCast[B, Double]
): TypedAggregate[T, Option[Double]] = {
- new TypedAggregate[T, Option[Double]](
- sparkFunctions.covar_samp(column1.cast[Double].untyped, column2.cast[Double].untyped)
- )
- }
-
+ new TypedAggregate[T, Option[Double]](
+ sparkFunctions
+ .covar_samp(column1.cast[Double].untyped, column2.cast[Double].untyped)
+ )
+ }
/**
- * Aggregate function: returns the kurtosis of a column.
- *
- * @note In Spark kurtosis always returns Double
- * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L220]]
- *
- * apache/spark
- */
- def kurtosis[A, T](column: TypedColumn[T, A])(implicit ev: CatalystCast[A, Double]): TypedAggregate[T, Option[Double]] = {
+ * Aggregate function: returns the kurtosis of a column.
+ *
+ * @note In Spark kurtosis always returns Double
+ * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L220]]
+ *
+ * apache/spark
+ */
+ def kurtosis[A, T](
+ column: TypedColumn[T, A]
+ )(implicit
+ ev: CatalystCast[A, Double]
+ ): TypedAggregate[T, Option[Double]] = {
new TypedAggregate[T, Option[Double]](
sparkFunctions.kurtosis(column.cast[Double].untyped)
)
}
/**
- * Aggregate function: returns the skewness of a column.
- *
- * @note In Spark skewness always returns Double
- * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L200]]
- *
- * apache/spark
- */
- def skewness[A, T](column: TypedColumn[T, A])(implicit ev: CatalystCast[A, Double]): TypedAggregate[T, Option[Double]] = {
+ * Aggregate function: returns the skewness of a column.
+ *
+ * @note In Spark skewness always returns Double
+ * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L200]]
+ *
+ * apache/spark
+ */
+ def skewness[A, T](
+ column: TypedColumn[T, A]
+ )(implicit
+ ev: CatalystCast[A, Double]
+ ): TypedAggregate[T, Option[Double]] = {
new TypedAggregate[T, Option[Double]](
sparkFunctions.skewness(column.cast[Double].untyped)
)
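
A hedged example tying a few of the aggregations above together; the `Purchase` case class and the implicit SparkSession are illustrative assumptions:

  import frameless.TypedDataset
  import frameless.functions.aggregate._ // count, sum, avg, ...

  case class Purchase(user: String, amount: Double)

  // an implicit SparkSession is assumed to be in scope, as in the earlier sketch
  val purchases: TypedDataset[Purchase] = TypedDataset.create(
    Seq(Purchase("ada", 10.0), Purchase("ada", 5.0), Purchase("bob", 1.0))
  )

  // per-user row count, total and average amount;
  // roughly a TypedDataset[(String, Long, Double, Double)]
  val byUser = purchases
    .groupBy(purchases('user))
    .agg(count[Purchase](), sum(purchases('amount)), avg(purchases('amount)))
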
diff --git a/dataset/src/main/scala/frameless/functions/Lit.scala b/dataset/src/main/scala/frameless/functions/Lit.scala
index d01467b13..78ee2fd3b 100644
--- a/dataset/src/main/scala/frameless/functions/Lit.scala
+++ b/dataset/src/main/scala/frameless/functions/Lit.scala
@@ -2,7 +2,10 @@ package frameless.functions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression}
+import org.apache.spark.sql.catalyst.expressions.{
+ Expression,
+ NonSQLExpression
+}
import org.apache.spark.sql.types.DataType
private[frameless] case class Lit[T <: AnyVal](
@@ -10,7 +13,8 @@ private[frameless] case class Lit[T <: AnyVal](
nullable: Boolean,
show: () => String,
catalystExpr: Expression // must be a generated Expression from a literal TypedEncoder's toCatalyst function
-) extends Expression with NonSQLExpression {
+ ) extends Expression
+ with NonSQLExpression {
override def toString: String = s"FramelessLit(${show()})"
lazy val codegen = {
@@ -52,12 +56,15 @@ private[frameless] case class Lit[T <: AnyVal](
}
def eval(input: InternalRow): Any = codegen(input)
-
+
def children: Seq[Expression] = Nil
- protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = catalystExpr.genCode(ctx)
+ protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+ catalystExpr.genCode(ctx)
- protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = this
+ protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]
+ ): Expression = this
override val foldable: Boolean = catalystExpr.foldable
}
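
`Lit` appears to be the Catalyst expression backing the public `lit` helper in `frameless.functions`; a hedged sketch of how that surfaces to users (the `Account` case class and the implicit SparkSession are illustrative assumptions):

  import frameless.TypedDataset
  import frameless.functions.lit

  case class Account(id: Long, balance: Double)

  // an implicit SparkSession is assumed to be in scope, as in the earlier sketch
  val accounts: TypedDataset[Account] =
    TypedDataset.create(Seq(Account(1L, 100.0), Account(2L, 25.5)))

  // lit lifts a plain Scala value into a TypedColumn; here it adds 10.0 to every balance
  val withBonus: TypedDataset[Double] =
    accounts.select(accounts('balance) + lit[Double, Account](10.0))
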
diff --git a/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala b/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala
index 939bf5b8d..4ce3c63cf 100644
--- a/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala
+++ b/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala
@@ -1,537 +1,729 @@
package frameless
package functions
-import org.apache.spark.sql.{Column, functions => sparkFunctions}
+import org.apache.spark.sql.{ Column, functions => sparkFunctions }
import scala.annotation.nowarn
import scala.util.matching.Regex
trait NonAggregateFunctions {
- /** Non-Aggregate function: calculates the SHA-2 digest of a binary column and returns the value as a 40 character hex string
- *
- * apache/spark
- */
- def sha2[T](column: AbstractTypedColumn[T, Array[Byte]], numBits: Int): column.ThisType[T, String] =
+
+ /**
+   * Non-Aggregate function: calculates the SHA-2 family digest of a binary column and returns the value as a hex string
+ *
+ * apache/spark
+ */
+ def sha2[T](
+ column: AbstractTypedColumn[T, Array[Byte]],
+ numBits: Int
+ ): column.ThisType[T, String] =
column.typed(sparkFunctions.sha2(column.untyped, numBits))
- /** Non-Aggregate function: calculates the SHA-1 digest of a binary column and returns the value as a 40 character hex string
- *
- * apache/spark
- */
- def sha1[T](column: AbstractTypedColumn[T, Array[Byte]]): column.ThisType[T, String] =
+ /**
+ * Non-Aggregate function: calculates the SHA-1 digest of a binary column and returns the value as a 40 character hex string
+ *
+ * apache/spark
+ */
+ def sha1[T](
+ column: AbstractTypedColumn[T, Array[Byte]]
+ ): column.ThisType[T, String] =
column.typed(sparkFunctions.sha1(column.untyped))
- /** Non-Aggregate function: returns a cyclic redundancy check value of a binary column as long.
- *
- * apache/spark
- */
- def crc32[T](column: AbstractTypedColumn[T, Array[Byte]]): column.ThisType[T, Long] =
+ /**
+   * Non-Aggregate function: returns a cyclic redundancy check value of a binary column as a long.
+ *
+ * apache/spark
+ */
+ def crc32[T](
+ column: AbstractTypedColumn[T, Array[Byte]]
+ ): column.ThisType[T, Long] =
column.typed(sparkFunctions.crc32(column.untyped))
+
/**
- * Non-Aggregate function: returns the negated value of column.
- *
- * apache/spark
- */
- def negate[A, B, T](column: AbstractTypedColumn[T,A])(
- implicit i0: CatalystNumericWithJavaBigDecimal[A, B],
- i1: TypedEncoder[B]
- ): column.ThisType[T,B] =
+ * Non-Aggregate function: returns the negated value of column.
+ *
+ * apache/spark
+ */
+ def negate[A, B, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystNumericWithJavaBigDecimal[A, B],
+ i1: TypedEncoder[B]
+ ): column.ThisType[T, B] =
column.typed(sparkFunctions.negate(column.untyped))
/**
- * Non-Aggregate function: logical not.
- *
- * apache/spark
- */
- def not[T](column: AbstractTypedColumn[T,Boolean]): column.ThisType[T,Boolean] =
+ * Non-Aggregate function: logical not.
+ *
+ * apache/spark
+ */
+ def not[T](
+ column: AbstractTypedColumn[T, Boolean]
+ ): column.ThisType[T, Boolean] =
column.typed(sparkFunctions.not(column.untyped))
/**
- * Non-Aggregate function: Convert a number in a string column from one base to another.
- *
- * apache/spark
- */
- def conv[T](column: AbstractTypedColumn[T,String], fromBase: Int, toBase: Int): column.ThisType[T,String] =
- column.typed(sparkFunctions.conv(column.untyped,fromBase,toBase))
+ * Non-Aggregate function: Convert a number in a string column from one base to another.
+ *
+ * apache/spark
+ */
+ def conv[T](
+ column: AbstractTypedColumn[T, String],
+ fromBase: Int,
+ toBase: Int
+ ): column.ThisType[T, String] =
+ column.typed(sparkFunctions.conv(column.untyped, fromBase, toBase))
- /** Non-Aggregate function: Converts an angle measured in radians to an approximately equivalent angle measured in degrees.
- *
- * apache/spark
- */
- def degrees[A,T](column: AbstractTypedColumn[T,A]): column.ThisType[T,Double] =
+ /**
+ * Non-Aggregate function: Converts an angle measured in radians to an approximately equivalent angle measured in degrees.
+ *
+ * apache/spark
+ */
+ def degrees[A, T](
+ column: AbstractTypedColumn[T, A]
+ ): column.ThisType[T, Double] =
column.typed(sparkFunctions.degrees(column.untyped))
- /** Non-Aggregate function: returns the ceiling of a numeric column
- *
- * apache/spark
- */
- def ceil[A, B, T](column: AbstractTypedColumn[T, A])
- (implicit
+ /**
+ * Non-Aggregate function: returns the ceiling of a numeric column
+ *
+ * apache/spark
+ */
+ def ceil[A, B, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystRound[A, B],
+ i1: TypedEncoder[B]
+ ): column.ThisType[T, B] =
+ column.typed(sparkFunctions.ceil(column.untyped))(i1)
+
+ /**
+ * Non-Aggregate function: returns the floor of a numeric column
+ *
+ * apache/spark
+ */
+ def floor[A, B, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
i0: CatalystRound[A, B],
i1: TypedEncoder[B]
): column.ThisType[T, B] =
- column.typed(sparkFunctions.ceil(column.untyped))(i1)
-
- /** Non-Aggregate function: returns the floor of a numeric column
- *
- * apache/spark
- */
- def floor[A, B, T](column: AbstractTypedColumn[T, A])
- (implicit
- i0: CatalystRound[A, B],
- i1: TypedEncoder[B]
- ): column.ThisType[T, B] =
column.typed(sparkFunctions.floor(column.untyped))(i1)
- /** Non-Aggregate function: unsigned shift the the given value numBits right. If given long, will return long else it will return an integer.
- *
- * apache/spark
- */
+ /**
+   * Non-Aggregate function: unsigned shift the given value numBits right. If given a long, it will return a long, otherwise an integer.
+ *
+ * apache/spark
+ */
  @nowarn // suppress sparkFunctions.shiftRightUnsigned call which is used to maintain Spark 3.1.x backwards compat
- def shiftRightUnsigned[A, B, T](column: AbstractTypedColumn[T, A], numBits: Int)
- (implicit
+ def shiftRightUnsigned[A, B, T](
+ column: AbstractTypedColumn[T, A],
+ numBits: Int
+ )(implicit
i0: CatalystBitShift[A, B],
i1: TypedEncoder[B]
): column.ThisType[T, B] =
- column.typed(sparkFunctions.shiftRightUnsigned(column.untyped, numBits))
+ column.typed(sparkFunctions.shiftRightUnsigned(column.untyped, numBits))
- /** Non-Aggregate function: shift the the given value numBits right. If given long, will return long else it will return an integer.
- *
- * apache/spark
- */
+ /**
+   * Non-Aggregate function: shift the given value numBits right. If given a long, it will return a long, otherwise an integer.
+ *
+ * apache/spark
+ */
  @nowarn // suppress sparkFunctions.shiftRight call which is used to maintain Spark 3.1.x backwards compat
- def shiftRight[A, B, T](column: AbstractTypedColumn[T, A], numBits: Int)
- (implicit
+ def shiftRight[A, B, T](
+ column: AbstractTypedColumn[T, A],
+ numBits: Int
+ )(implicit
i0: CatalystBitShift[A, B],
i1: TypedEncoder[B]
): column.ThisType[T, B] =
- column.typed(sparkFunctions.shiftRight(column.untyped, numBits))
+ column.typed(sparkFunctions.shiftRight(column.untyped, numBits))
- /** Non-Aggregate function: shift the the given value numBits left. If given long, will return long else it will return an integer.
- *
- * apache/spark
- */
+ /**
+   * Non-Aggregate function: shift the given value numBits left. If given a long, it will return a long, otherwise an integer.
+ *
+ * apache/spark
+ */
  @nowarn // suppress sparkFunctions.shiftLeft call which is used to maintain Spark 3.1.x backwards compat
- def shiftLeft[A, B, T](column: AbstractTypedColumn[T, A], numBits: Int)
- (implicit
+ def shiftLeft[A, B, T](
+ column: AbstractTypedColumn[T, A],
+ numBits: Int
+ )(implicit
i0: CatalystBitShift[A, B],
i1: TypedEncoder[B]
): column.ThisType[T, B] =
column.typed(sparkFunctions.shiftLeft(column.untyped, numBits))
-
- /** Non-Aggregate function: returns the absolute value of a numeric column
- *
- * apache/spark
- */
- def abs[A, B, T](column: AbstractTypedColumn[T, A])
- (implicit
- i0: CatalystNumericWithJavaBigDecimal[A, B],
- i1: TypedEncoder[B]
+
+ /**
+ * Non-Aggregate function: returns the absolute value of a numeric column
+ *
+ * apache/spark
+ */
+ def abs[A, B, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystNumericWithJavaBigDecimal[A, B],
+ i1: TypedEncoder[B]
): column.ThisType[T, B] =
- column.typed(sparkFunctions.abs(column.untyped))(i1)
-
- /** Non-Aggregate function: Computes the cosine of the given value.
- *
- * Spark will expect a Double value for this expression. See:
- * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
- * apache/spark
- */
- def cos[A, T](column: AbstractTypedColumn[T, A])
- (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] =
- column.typed(sparkFunctions.cos(column.cast[Double].untyped))
-
- /** Non-Aggregate function: Computes the hyperbolic cosine of the given value.
- *
- * Spark will expect a Double value for this expression. See:
- * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
- * apache/spark
- */
- def cosh[A, T](column: AbstractTypedColumn[T, A])
- (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] =
- column.typed(sparkFunctions.cosh(column.cast[Double].untyped))
-
- /** Non-Aggregate function: Computes the signum of the given value.
- *
- * Spark will expect a Double value for this expression. See:
- * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
- * apache/spark
- */
- def signum[A, T](column: AbstractTypedColumn[T, A])
- (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] =
+ column.typed(sparkFunctions.abs(column.untyped))(i1)
+
+ /**
+ * Non-Aggregate function: Computes the cosine of the given value.
+ *
+ * Spark will expect a Double value for this expression. See:
+ * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
+ * apache/spark
+ */
+ def cos[A, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
+ column.typed(sparkFunctions.cos(column.cast[Double].untyped))
+
+ /**
+ * Non-Aggregate function: Computes the hyperbolic cosine of the given value.
+ *
+ * Spark will expect a Double value for this expression. See:
+ * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
+ * apache/spark
+ */
+ def cosh[A, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
+ column.typed(sparkFunctions.cosh(column.cast[Double].untyped))
+
+ /**
+ * Non-Aggregate function: Computes the signum of the given value.
+ *
+ * Spark will expect a Double value for this expression. See:
+ * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
+ * apache/spark
+ */
+ def signum[A, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
column.typed(sparkFunctions.signum(column.cast[Double].untyped))
- /** Non-Aggregate function: Computes the sine of the given value.
- *
- * Spark will expect a Double value for this expression. See:
- * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
- * apache/spark
- */
- def sin[A, T](column: AbstractTypedColumn[T, A])
- (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] =
- column.typed(sparkFunctions.sin(column.cast[Double].untyped))
-
- /** Non-Aggregate function: Computes the hyperbolic sine of the given value.
- *
- * Spark will expect a Double value for this expression. See:
- * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
- * apache/spark
- */
- def sinh[A, T](column: AbstractTypedColumn[T, A])
- (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] =
- column.typed(sparkFunctions.sinh(column.cast[Double].untyped))
-
- /** Non-Aggregate function: Computes the tangent of the given column.
- *
- * Spark will expect a Double value for this expression. See:
- * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
- * apache/spark
- */
- def tan[A, T](column: AbstractTypedColumn[T, A])
- (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] =
- column.typed(sparkFunctions.tan(column.cast[Double].untyped))
-
- /** Non-Aggregate function: Computes the hyperbolic tangent of the given value.
- *
- * Spark will expect a Double value for this expression. See:
- * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
- * apache/spark
- */
- def tanh[A, T](column: AbstractTypedColumn[T, A])
- (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] =
- column.typed(sparkFunctions.tanh(column.cast[Double].untyped))
-
- /** Non-Aggregate function: returns the acos of a numeric column
- *
- * Spark will expect a Double value for this expression. See:
- * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
- * apache/spark
- */
- def acos[A, T](column: AbstractTypedColumn[T, A])
- (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] =
- column.typed(sparkFunctions.acos(column.cast[Double].untyped))
-
- /** Non-Aggregate function: returns true if value is contained with in the array in the specified column
- *
- * apache/spark
- */
- def arrayContains[C[_]: CatalystCollection, A, T](column: AbstractTypedColumn[T, C[A]], value: A): column.ThisType[T, Boolean] =
+ /**
+ * Non-Aggregate function: Computes the sine of the given value.
+ *
+ * Spark will expect a Double value for this expression. See:
+ * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
+ * apache/spark
+ */
+ def sin[A, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
+ column.typed(sparkFunctions.sin(column.cast[Double].untyped))
+
+ /**
+ * Non-Aggregate function: Computes the hyperbolic sine of the given value.
+ *
+ * Spark will expect a Double value for this expression. See:
+ * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
+ * apache/spark
+ */
+ def sinh[A, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
+ column.typed(sparkFunctions.sinh(column.cast[Double].untyped))
+
+ /**
+ * Non-Aggregate function: Computes the tangent of the given column.
+ *
+ * Spark will expect a Double value for this expression. See:
+ * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
+ * apache/spark
+ */
+ def tan[A, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
+ column.typed(sparkFunctions.tan(column.cast[Double].untyped))
+
+ /**
+ * Non-Aggregate function: Computes the hyperbolic tangent of the given value.
+ *
+ * Spark will expect a Double value for this expression. See:
+ * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
+ * apache/spark
+ */
+ def tanh[A, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
+ column.typed(sparkFunctions.tanh(column.cast[Double].untyped))
+
+ /**
+ * Non-Aggregate function: returns the acos of a numeric column
+ *
+ * Spark will expect a Double value for this expression. See:
+ * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
+ * apache/spark
+ */
+ def acos[A, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
+ column.typed(sparkFunctions.acos(column.cast[Double].untyped))
+
+ /**
+   * Non-Aggregate function: returns true if the value is contained within the array in the specified column
+ *
+ * apache/spark
+ */
+ def arrayContains[C[_]: CatalystCollection, A, T](
+ column: AbstractTypedColumn[T, C[A]],
+ value: A
+ ): column.ThisType[T, Boolean] =
column.typed(sparkFunctions.array_contains(column.untyped, value))
- /** Non-Aggregate function: returns the atan of a numeric column
- *
- * Spark will expect a Double value for this expression. See:
- * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
- * apache/spark
- */
- def atan[A, T](column: AbstractTypedColumn[T,A])
- (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] =
- column.typed(sparkFunctions.atan(column.cast[Double].untyped))
-
- /** Non-Aggregate function: returns the asin of a numeric column
- *
- * Spark will expect a Double value for this expression. See:
- * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
- * apache/spark
- */
- def asin[A, T](column: AbstractTypedColumn[T, A])
- (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] =
- column.typed(sparkFunctions.asin(column.cast[Double].untyped))
-
- /** Non-Aggregate function: returns the angle theta from the conversion of rectangular coordinates (x, y) to
- * polar coordinates (r, theta).
- *
- * Spark will expect a Double value for this expression. See:
- * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
- * apache/spark
- */
- def atan2[A, B, T](l: TypedColumn[T, A], r: TypedColumn[T, B])
- (implicit
+ /**
+ * Non-Aggregate function: returns the atan of a numeric column
+ *
+ * Spark will expect a Double value for this expression. See:
+ * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
+ * apache/spark
+ */
+ def atan[A, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
+ column.typed(sparkFunctions.atan(column.cast[Double].untyped))
+
+ /**
+ * Non-Aggregate function: returns the asin of a numeric column
+ *
+ * Spark will expect a Double value for this expression. See:
+ * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
+ * apache/spark
+ */
+ def asin[A, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
+ column.typed(sparkFunctions.asin(column.cast[Double].untyped))
+
+ /**
+ * Non-Aggregate function: returns the angle theta from the conversion of rectangular coordinates (x, y) to
+ * polar coordinates (r, theta).
+ *
+ * Spark will expect a Double value for this expression. See:
+ * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
+ * apache/spark
+ */
+ def atan2[A, B, T](
+ l: TypedColumn[T, A],
+ r: TypedColumn[T, B]
+ )(implicit
i0: CatalystCast[A, Double],
i1: CatalystCast[B, Double]
): TypedColumn[T, Double] =
- r.typed(sparkFunctions.atan2(l.cast[Double].untyped, r.cast[Double].untyped))
-
- /** Non-Aggregate function: returns the angle theta from the conversion of rectangular coordinates (x, y) to
- * polar coordinates (r, theta).
- *
- * Spark will expect a Double value for this expression. See:
- * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
- * apache/spark
- */
- def atan2[A, B, T](l: TypedAggregate[T, A], r: TypedAggregate[T, B])
- (implicit
+ r.typed(
+ sparkFunctions.atan2(l.cast[Double].untyped, r.cast[Double].untyped)
+ )
+
+ /**
+ * Non-Aggregate function: returns the angle theta from the conversion of rectangular coordinates (x, y) to
+ * polar coordinates (r, theta).
+ *
+ * Spark will expect a Double value for this expression. See:
+ * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]]
+ * apache/spark
+ */
+ def atan2[A, B, T](
+ l: TypedAggregate[T, A],
+ r: TypedAggregate[T, B]
+ )(implicit
i0: CatalystCast[A, Double],
i1: CatalystCast[B, Double]
): TypedAggregate[T, Double] =
- r.typed(sparkFunctions.atan2(l.cast[Double].untyped, r.cast[Double].untyped))
-
- def atan2[B, T](l: Double, r: TypedColumn[T, B])
- (implicit i0: CatalystCast[B, Double]): TypedColumn[T, Double] =
- atan2(r.lit(l), r)
-
- def atan2[A, T](l: TypedColumn[T, A], r: Double)
- (implicit i0: CatalystCast[A, Double]): TypedColumn[T, Double] =
- atan2(l, l.lit(r))
-
- def atan2[B, T](l: Double, r: TypedAggregate[T, B])
- (implicit i0: CatalystCast[B, Double]): TypedAggregate[T, Double] =
- atan2(r.lit(l), r)
-
- def atan2[A, T](l: TypedAggregate[T, A], r: Double)
- (implicit i0: CatalystCast[A, Double]): TypedAggregate[T, Double] =
- atan2(l, l.lit(r))
-
- /** Non-Aggregate function: returns the square root value of a numeric column.
- *
- * apache/spark
- */
- def sqrt[A, T](column: AbstractTypedColumn[T, A])
- (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] =
+ r.typed(
+ sparkFunctions.atan2(l.cast[Double].untyped, r.cast[Double].untyped)
+ )
+
+ def atan2[B, T](
+ l: Double,
+ r: TypedColumn[T, B]
+ )(implicit
+ i0: CatalystCast[B, Double]
+ ): TypedColumn[T, Double] =
+ atan2(r.lit(l), r)
+
+ def atan2[A, T](
+ l: TypedColumn[T, A],
+ r: Double
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): TypedColumn[T, Double] =
+ atan2(l, l.lit(r))
+
+ def atan2[B, T](
+ l: Double,
+ r: TypedAggregate[T, B]
+ )(implicit
+ i0: CatalystCast[B, Double]
+ ): TypedAggregate[T, Double] =
+ atan2(r.lit(l), r)
+
+ def atan2[A, T](
+ l: TypedAggregate[T, A],
+ r: Double
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): TypedAggregate[T, Double] =
+ atan2(l, l.lit(r))
+
+ /**
+ * Non-Aggregate function: returns the square root value of a numeric column.
+ *
+ * apache/spark
+ */
+ def sqrt[A, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
column.typed(sparkFunctions.sqrt(column.cast[Double].untyped))
- /** Non-Aggregate function: returns the cubic root value of a numeric column.
- *
- * apache/spark
- */
- def cbrt[A, T](column: AbstractTypedColumn[T, A])
- (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] =
+ /**
+ * Non-Aggregate function: returns the cubic root value of a numeric column.
+ *
+ * apache/spark
+ */
+ def cbrt[A, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
column.typed(sparkFunctions.cbrt(column.cast[Double].untyped))
- /** Non-Aggregate function: returns the exponential value of a numeric column.
- *
- * apache/spark
- */
- def exp[A, T](column: AbstractTypedColumn[T, A])
- (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] =
+ /**
+ * Non-Aggregate function: returns the exponential value of a numeric column.
+ *
+ * apache/spark
+ */
+ def exp[A, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
column.typed(sparkFunctions.exp(column.cast[Double].untyped))
- /** Non-Aggregate function: Returns the value of the column `e` rounded to 0 decimal places with HALF_UP round mode.
- *
- * apache/spark
- */
- def round[A, B, T](column: AbstractTypedColumn[T, A])(
- implicit i0: CatalystNumericWithJavaBigDecimal[A, B], i1: TypedEncoder[B]
- ): column.ThisType[T, B] =
+ /**
+ * Non-Aggregate function: Returns the value of the column `e` rounded to 0 decimal places with HALF_UP round mode.
+ *
+ * apache/spark
+ */
+ def round[A, B, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystNumericWithJavaBigDecimal[A, B],
+ i1: TypedEncoder[B]
+ ): column.ThisType[T, B] =
column.typed(sparkFunctions.round(column.untyped))(i1)
- /** Non-Aggregate function: Round the value of `e` to `scale` decimal places with HALF_UP round mode
- * if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0.
- *
- * apache/spark
- */
- def round[A, B, T](column: AbstractTypedColumn[T, A], scale: Int)(
- implicit i0: CatalystNumericWithJavaBigDecimal[A, B], i1: TypedEncoder[B]
- ): column.ThisType[T, B] =
+ /**
+ * Non-Aggregate function: Round the value of `e` to `scale` decimal places with HALF_UP round mode
+   * if `scale` is greater than or equal to 0, or at the integral part when `scale` is less than 0.
+ *
+ * apache/spark
+ */
+ def round[A, B, T](
+ column: AbstractTypedColumn[T, A],
+ scale: Int
+ )(implicit
+ i0: CatalystNumericWithJavaBigDecimal[A, B],
+ i1: TypedEncoder[B]
+ ): column.ThisType[T, B] =
column.typed(sparkFunctions.round(column.untyped, scale))(i1)
- /** Non-Aggregate function: Bankers Rounding - returns the rounded to 0 decimal places value with HALF_EVEN round mode
- * of a numeric column.
- *
- * apache/spark
- */
- def bround[A, B, T](column: AbstractTypedColumn[T, A])(
- implicit i0: CatalystNumericWithJavaBigDecimal[A, B], i1: TypedEncoder[B]
- ): column.ThisType[T, B] =
+ /**
+   * Non-Aggregate function: Bankers Rounding - returns the value of a numeric column
+   * rounded to 0 decimal places with HALF_EVEN round mode.
+ *
+ * apache/spark
+ */
+ def bround[A, B, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystNumericWithJavaBigDecimal[A, B],
+ i1: TypedEncoder[B]
+ ): column.ThisType[T, B] =
column.typed(sparkFunctions.bround(column.untyped))(i1)
- /** Non-Aggregate function: Bankers Rounding - returns the rounded to `scale` decimal places value with HALF_EVEN round mode
- * of a numeric column. If `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0.
- *
- * apache/spark
- */
- def bround[A, B, T](column: AbstractTypedColumn[T, A], scale: Int)(
- implicit i0: CatalystNumericWithJavaBigDecimal[A, B], i1: TypedEncoder[B]
- ): column.ThisType[T, B] =
+ /**
+   * Non-Aggregate function: Bankers Rounding - returns the value of a numeric column rounded to `scale` decimal places
+   * with HALF_EVEN round mode if `scale` is greater than or equal to 0, or at the integral part when `scale` is less than 0.
+ *
+ * apache/spark
+ */
+ def bround[A, B, T](
+ column: AbstractTypedColumn[T, A],
+ scale: Int
+ )(implicit
+ i0: CatalystNumericWithJavaBigDecimal[A, B],
+ i1: TypedEncoder[B]
+ ): column.ThisType[T, B] =
column.typed(sparkFunctions.bround(column.untyped, scale))(i1)
/**
- * Computes the natural logarithm of the given value.
- *
- * apache/spark
- */
- def log[A, T](column: AbstractTypedColumn[T, A])(
- implicit i0: CatalystCast[A, Double]
- ): column.ThisType[T, Double] =
+ * Computes the natural logarithm of the given value.
+ *
+ * apache/spark
+ */
+ def log[A, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
column.typed(sparkFunctions.log(column.untyped))
/**
- * Returns the first argument-base logarithm of the second argument.
- *
- * apache/spark
- */
- def log[A, T](base: Double, column: AbstractTypedColumn[T, A])(
- implicit i0: CatalystCast[A, Double]
- ): column.ThisType[T, Double] =
+ * Returns the first argument-base logarithm of the second argument.
+ *
+ * apache/spark
+ */
+ def log[A, T](
+ base: Double,
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
column.typed(sparkFunctions.log(base, column.untyped))
/**
- * Computes the logarithm of the given column in base 2.
- *
- * apache/spark
- */
- def log2[A, T](column: AbstractTypedColumn[T, A])(
- implicit i0: CatalystCast[A, Double]
- ): column.ThisType[T, Double] =
+ * Computes the logarithm of the given column in base 2.
+ *
+ * apache/spark
+ */
+ def log2[A, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
column.typed(sparkFunctions.log2(column.untyped))
/**
- * Computes the natural logarithm of the given value plus one.
- *
- * apache/spark
- */
- def log1p[A, T](column: AbstractTypedColumn[T, A])(
- implicit i0: CatalystCast[A, Double]
- ): column.ThisType[T, Double] =
+ * Computes the natural logarithm of the given value plus one.
+ *
+ * apache/spark
+ */
+ def log1p[A, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
column.typed(sparkFunctions.log1p(column.untyped))
/**
- * Computes the logarithm of the given column in base 10.
- *
- * apache/spark
- */
- def log10[A, T](column: AbstractTypedColumn[T, A])(
- implicit i0: CatalystCast[A, Double]
- ): column.ThisType[T, Double] =
+ * Computes the logarithm of the given column in base 10.
+ *
+ * apache/spark
+ */
+ def log10[A, T](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
column.typed(sparkFunctions.log10(column.untyped))
-
/**
- * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.
- *
- * apache/spark
- */
- def hypot[A, T](column: AbstractTypedColumn[T, A], column2: AbstractTypedColumn[T, A])(
- implicit i0: CatalystCast[A, Double]
- ): column.ThisType[T, Double] =
+ * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.
+ *
+ * apache/spark
+ */
+ def hypot[A, T](
+ column: AbstractTypedColumn[T, A],
+ column2: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
column.typed(sparkFunctions.hypot(column.untyped, column2.untyped))
/**
- * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.
- *
- * apache/spark
- */
- def hypot[A, T](column: AbstractTypedColumn[T, A], l: Double)(
- implicit i0: CatalystCast[A, Double]
- ): column.ThisType[T, Double] =
+ * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.
+ *
+ * apache/spark
+ */
+ def hypot[A, T](
+ column: AbstractTypedColumn[T, A],
+ l: Double
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
column.typed(sparkFunctions.hypot(column.untyped, l))
/**
- * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.
- *
- * apache/spark
- */
- def hypot[A, T](l: Double, column: AbstractTypedColumn[T, A])(
- implicit i0: CatalystCast[A, Double]
- ): column.ThisType[T, Double] =
+ * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.
+ *
+ * apache/spark
+ */
+ def hypot[A, T](
+ l: Double,
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
column.typed(sparkFunctions.hypot(l, column.untyped))
/**
- * Returns the value of the first argument raised to the power of the second argument.
- *
- * apache/spark
- */
- def pow[A, T](column: AbstractTypedColumn[T, A], column2: AbstractTypedColumn[T, A])(
- implicit i0: CatalystCast[A, Double]
- ): column.ThisType[T, Double] =
+ * Returns the value of the first argument raised to the power of the second argument.
+ *
+ * apache/spark
+ */
+ def pow[A, T](
+ column: AbstractTypedColumn[T, A],
+ column2: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
column.typed(sparkFunctions.pow(column.untyped, column2.untyped))
/**
- * Returns the value of the first argument raised to the power of the second argument.
- *
- * apache/spark
- */
- def pow[A, T](column: AbstractTypedColumn[T, A], l: Double)(
- implicit i0: CatalystCast[A, Double]
- ): column.ThisType[T, Double] =
+ * Returns the value of the first argument raised to the power of the second argument.
+ *
+ * apache/spark
+ */
+ def pow[A, T](
+ column: AbstractTypedColumn[T, A],
+ l: Double
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
column.typed(sparkFunctions.pow(column.untyped, l))
/**
- * Returns the value of the first argument raised to the power of the second argument.
- *
- * apache/spark
- */
- def pow[A, T](l: Double, column: AbstractTypedColumn[T, A])(
- implicit i0: CatalystCast[A, Double]
- ): column.ThisType[T, Double] =
+ * Returns the value of the first argument raised to the power of the second argument.
+ *
+ * apache/spark
+ */
+ def pow[A, T](
+ l: Double,
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: CatalystCast[A, Double]
+ ): column.ThisType[T, Double] =
column.typed(sparkFunctions.pow(l, column.untyped))
/**
- * Returns the positive value of dividend mod divisor.
- *
- * apache/spark
- */
- def pmod[A, T](column: AbstractTypedColumn[T, A], column2: AbstractTypedColumn[T, A])(
- implicit i0: TypedEncoder[A]
- ): column.ThisType[T, A] =
+ * Returns the positive value of dividend mod divisor.
+ *
+ * apache/spark
+ */
+ def pmod[A, T](
+ column: AbstractTypedColumn[T, A],
+ column2: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: TypedEncoder[A]
+ ): column.ThisType[T, A] =
column.typed(sparkFunctions.pmod(column.untyped, column2.untyped))
-
- /** Non-Aggregate function: Returns the string representation of the binary value of the given long
- * column. For example, bin("12") returns "1100".
- *
- * apache/spark
- */
+ /**
+ * Non-Aggregate function: Returns the string representation of the binary value of the given long
+ * column. For example, bin("12") returns "1100".
+ *
+ * apache/spark
+ */
def bin[T](column: AbstractTypedColumn[T, Long]): column.ThisType[T, String] =
column.typed(sparkFunctions.bin(column.untyped))
/**
- * Calculates the MD5 digest of a binary column and returns the value
- * as a 32 character hex string.
- *
- * apache/spark
- */
- def md5[T, A](column: AbstractTypedColumn[T, A])(implicit i0: TypedEncoder[A]): column.ThisType[T, String] =
+ * Calculates the MD5 digest of a binary column and returns the value
+ * as a 32 character hex string.
+ *
+ * apache/spark
+ */
+ def md5[T, A](
+ column: AbstractTypedColumn[T, A]
+ )(implicit
+ i0: TypedEncoder[A]
+ ): column.ThisType[T, String] =
column.typed(sparkFunctions.md5(column.untyped))
/**
- * Computes the factorial of the given value.
- *
- * apache/spark
- */
- def factorial[T](column: AbstractTypedColumn[T, Long])(implicit i0: TypedEncoder[Long]): column.ThisType[T, Long] =
+ * Computes the factorial of the given value.
+ *
+ * apache/spark
+ */
+ def factorial[T](
+ column: AbstractTypedColumn[T, Long]
+ )(implicit
+ i0: TypedEncoder[Long]
+ ): column.ThisType[T, Long] =
column.typed(sparkFunctions.factorial(column.untyped))
- /** Non-Aggregate function: Computes bitwise NOT.
- *
- * apache/spark
- */
+ /**
+ * Non-Aggregate function: Computes bitwise NOT.
+ *
+ * apache/spark
+ */
  @nowarn // suppress sparkFunctions.bitwiseNOT call which is used to maintain Spark 3.1.x backwards compat
- def bitwiseNOT[A: CatalystBitwise, T](column: AbstractTypedColumn[T, A]): column.ThisType[T, A] =
+ def bitwiseNOT[A: CatalystBitwise, T](
+ column: AbstractTypedColumn[T, A]
+ ): column.ThisType[T, A] =
column.typed(sparkFunctions.bitwiseNOT(column.untyped))(column.uencoder)
- /** Non-Aggregate function: file name of the current Spark task. Empty string if row did not originate from
- * a file
- *
- * apache/spark
- */
+ /**
+ * Non-Aggregate function: file name of the current Spark task. Empty string if row did not originate from
+ * a file.
+ *
+ * apache/spark
+ */
def inputFileName[T](): TypedColumn[T, String] =
new TypedColumn[T, String](sparkFunctions.input_file_name())
- /** Non-Aggregate function: generates monotonically increasing id
- *
- * apache/spark
- */
+ /**
+ * Non-Aggregate function: generates monotonically increasing id
+ *
+ * apache/spark
+ */
def monotonicallyIncreasingId[T](): TypedColumn[T, Long] = {
new TypedColumn[T, Long](sparkFunctions.monotonically_increasing_id())
}
- /** Non-Aggregate function: Evaluates a list of conditions and returns one of multiple
- * possible result expressions. If none match, otherwise is returned
- * {{{
- * when(ds('boolField), ds('a))
- * .when(ds('otherBoolField), lit(123))
- * .otherwise(ds('b))
- * }}}
- * apache/spark
- */
- def when[T, A](condition: AbstractTypedColumn[T, Boolean], value: AbstractTypedColumn[T, A]): When[T, A] =
+ /**
+ * Non-Aggregate function: Evaluates a list of conditions and returns one of multiple
+ * possible result expressions. If none match, the `otherwise` value is returned
+ * {{{
+ * when(ds('boolField), ds('a))
+ * .when(ds('otherBoolField), lit(123))
+ * .otherwise(ds('b))
+ * }}}
+ * apache/spark
+ */
+ def when[T, A](
+ condition: AbstractTypedColumn[T, Boolean],
+ value: AbstractTypedColumn[T, A]
+ ): When[T, A] =
new When[T, A](condition, value)
class When[T, A] private (untypedC: Column) {
- private[functions] def this(condition: AbstractTypedColumn[T, Boolean], value: AbstractTypedColumn[T, A]) =
+ private[functions] def this(
+ condition: AbstractTypedColumn[T, Boolean],
+ value: AbstractTypedColumn[T, A]
+ ) =
this(sparkFunctions.when(condition.untyped, value.untyped))
- def when(condition: AbstractTypedColumn[T, Boolean], value: AbstractTypedColumn[T, A]): When[T, A] =
+ def when(
+ condition: AbstractTypedColumn[T, Boolean],
+ value: AbstractTypedColumn[T, A]
+ ): When[T, A] =
new When[T, A](untypedC.when(condition.untyped, value.untyped))
def otherwise(value: AbstractTypedColumn[T, A]): value.ThisType[T, A] =
@@ -542,172 +734,230 @@ trait NonAggregateFunctions {
// String functions
//////////////////////////////////////////////////////////////////////////////////////////////
-
- /** Non-Aggregate function: takes the first letter of a string column and returns the ascii int value in a new column
- *
- * apache/spark
- */
- def ascii[T](column: AbstractTypedColumn[T, String]): column.ThisType[T, Int] =
+ /**
+ * Non-Aggregate function: takes the first letter of a string column and returns the ascii int value in a new column
+ *
+ * apache/spark
+ */
+ def ascii[T](
+ column: AbstractTypedColumn[T, String]
+ ): column.ThisType[T, Int] =
column.typed(sparkFunctions.ascii(column.untyped))
- /** Non-Aggregate function: Computes the BASE64 encoding of a binary column and returns it as a string column.
- * This is the reverse of unbase64.
- *
- * apache/spark
- */
- def base64[T](column: AbstractTypedColumn[T, Array[Byte]]): column.ThisType[T, String] =
+ /**
+ * Non-Aggregate function: Computes the BASE64 encoding of a binary column and returns it as a string column.
+ * This is the reverse of unbase64.
+ *
+ * apache/spark
+ */
+ def base64[T](
+ column: AbstractTypedColumn[T, Array[Byte]]
+ ): column.ThisType[T, String] =
column.typed(sparkFunctions.base64(column.untyped))
- /** Non-Aggregate function: Decodes a BASE64 encoded string column and returns it as a binary column.
- * This is the reverse of base64.
- *
- * apache/spark
- */
- def unbase64[T](column: AbstractTypedColumn[T, String]): column.ThisType[T, Array[Byte]] =
+ /**
+ * Non-Aggregate function: Decodes a BASE64 encoded string column and returns it as a binary column.
+ * This is the reverse of base64.
+ *
+ * apache/spark
+ */
+ def unbase64[T](
+ column: AbstractTypedColumn[T, String]
+ ): column.ThisType[T, Array[Byte]] =
column.typed(sparkFunctions.unbase64(column.untyped))
- /** Non-Aggregate function: Concatenates multiple input string columns together into a single string column.
- * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]]
- *
- * apache/spark
- */
+ /**
+ * Non-Aggregate function: Concatenates multiple input string columns together into a single string column.
+ * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]]
+ *
+ * apache/spark
+ */
def concat[T](columns: TypedColumn[T, String]*): TypedColumn[T, String] =
new TypedColumn(sparkFunctions.concat(columns.map(_.untyped): _*))
- /** Non-Aggregate function: Concatenates multiple input string columns together into a single string column.
- * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]]
- *
- * apache/spark
- */
- def concat[T](columns: TypedAggregate[T, String]*): TypedAggregate[T, String] =
+ /**
+ * Non-Aggregate function: Concatenates multiple input string columns together into a single string column.
+ * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]]
+ *
+ * apache/spark
+ */
+ def concat[T](
+ columns: TypedAggregate[T, String]*
+ ): TypedAggregate[T, String] =
new TypedAggregate(sparkFunctions.concat(columns.map(_.untyped): _*))
- /** Non-Aggregate function: Concatenates multiple input string columns together into a single string column,
- * using the given separator.
- * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]]
- *
- * apache/spark
- */
- def concatWs[T](sep: String, columns: TypedAggregate[T, String]*): TypedAggregate[T, String] =
- new TypedAggregate(sparkFunctions.concat_ws(sep, columns.map(_.untyped): _*))
-
- /** Non-Aggregate function: Concatenates multiple input string columns together into a single string column,
- * using the given separator.
- * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]]
- *
- * apache/spark
- */
- def concatWs[T](sep: String, columns: TypedColumn[T, String]*): TypedColumn[T, String] =
+ /**
+ * Non-Aggregate function: Concatenates multiple input string columns together into a single string column,
+ * using the given separator.
+ * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]]
+ *
+ * apache/spark
+ */
+ def concatWs[T](
+ sep: String,
+ columns: TypedAggregate[T, String]*
+ ): TypedAggregate[T, String] =
+ new TypedAggregate(
+ sparkFunctions.concat_ws(sep, columns.map(_.untyped): _*)
+ )
+
+ /**
+ * Non-Aggregate function: Concatenates multiple input string columns together into a single string column,
+ * using the given separator.
+ * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]]
+ *
+ * apache/spark
+ */
+ def concatWs[T](
+ sep: String,
+ columns: TypedColumn[T, String]*
+ ): TypedColumn[T, String] =
new TypedColumn(sparkFunctions.concat_ws(sep, columns.map(_.untyped): _*))
- /** Non-Aggregate function: Locates the position of the first occurrence of substring column
- * in given string
- *
- * @note The position is not zero based, but 1 based index. Returns 0 if substr
- * could not be found in str.
- *
- * apache/spark
- */
- def instr[T](str: AbstractTypedColumn[T, String], substring: String): str.ThisType[T, Int] =
+ /**
+ * Non-Aggregate function: Locates the position of the first occurrence of substring column
+ * in a given string
+ *
+ * @note The position is not zero-based but a 1-based index. Returns 0 if substr
+ * could not be found in str.
+ *
+ * apache/spark
+ */
+ def instr[T](
+ str: AbstractTypedColumn[T, String],
+ substring: String
+ ): str.ThisType[T, Int] =
str.typed(sparkFunctions.instr(str.untyped, substring))
- /** Non-Aggregate function: Computes the length of a given string.
- *
- * apache/spark
- */
- //TODO: Also for binary
+ /**
+ * Non-Aggregate function: Computes the length of a given string.
+ *
+ * apache/spark
+ */
+ // TODO: Also for binary
def length[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Int] =
str.typed(sparkFunctions.length(str.untyped))
- /** Non-Aggregate function: Computes the Levenshtein distance of the two given string columns.
- *
- * apache/spark
- */
- def levenshtein[T](l: TypedColumn[T, String], r: TypedColumn[T, String]): TypedColumn[T, Int] =
+ /**
+ * Non-Aggregate function: Computes the Levenshtein distance of the two given string columns.
+ *
+ * apache/spark
+ */
+ def levenshtein[T](
+ l: TypedColumn[T, String],
+ r: TypedColumn[T, String]
+ ): TypedColumn[T, Int] =
l.typed(sparkFunctions.levenshtein(l.untyped, r.untyped))
- /** Non-Aggregate function: Computes the Levenshtein distance of the two given string columns.
- *
- * apache/spark
- */
- def levenshtein[T](l: TypedAggregate[T, String], r: TypedAggregate[T, String]): TypedAggregate[T, Int] =
+ /**
+ * Non-Aggregate function: Computes the Levenshtein distance of the two given string columns.
+ *
+ * apache/spark
+ */
+ def levenshtein[T](
+ l: TypedAggregate[T, String],
+ r: TypedAggregate[T, String]
+ ): TypedAggregate[T, Int] =
l.typed(sparkFunctions.levenshtein(l.untyped, r.untyped))
- /** Non-Aggregate function: Converts a string column to lower case.
- *
- * apache/spark
- */
+ /**
+ * Non-Aggregate function: Converts a string column to lower case.
+ *
+ * apache/spark
+ */
def lower[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] =
str.typed(sparkFunctions.lower(str.untyped))
- /** Non-Aggregate function: Left-pad the string column with pad to a length of len. If the string column is longer
- * than len, the return value is shortened to len characters.
- *
- * apache/spark
- */
- def lpad[T](str: AbstractTypedColumn[T, String],
- len: Int,
- pad: String): str.ThisType[T, String] =
+ /**
+ * Non-Aggregate function: Left-pad the string column with pad to a length of len. If the string column is longer
+ * than len, the return value is shortened to len characters.
+ *
+ * apache/spark
+ */
+ def lpad[T](
+ str: AbstractTypedColumn[T, String],
+ len: Int,
+ pad: String
+ ): str.ThisType[T, String] =
str.typed(sparkFunctions.lpad(str.untyped, len, pad))
- /** Non-Aggregate function: Trim the spaces from left end for the specified string value.
- *
- * apache/spark
- */
+ /**
+ * Non-Aggregate function: Trim the spaces from left end for the specified string value.
+ *
+ * apache/spark
+ */
def ltrim[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] =
str.typed(sparkFunctions.ltrim(str.untyped))
- /** Non-Aggregate function: Replace all substrings of the specified string value that match regexp with rep.
- *
- * apache/spark
- */
- def regexpReplace[T](str: AbstractTypedColumn[T, String],
- pattern: Regex,
- replacement: String): str.ThisType[T, String] =
- str.typed(sparkFunctions.regexp_replace(str.untyped, pattern.regex, replacement))
-
+ /**
+ * Non-Aggregate function: Replace all substrings of the specified string value that match regexp with rep.
+ *
+ * apache/spark
+ */
+ def regexpReplace[T](
+ str: AbstractTypedColumn[T, String],
+ pattern: Regex,
+ replacement: String
+ ): str.ThisType[T, String] =
+ str.typed(
+ sparkFunctions.regexp_replace(str.untyped, pattern.regex, replacement)
+ )
- /** Non-Aggregate function: Reverses the string column and returns it as a new string column.
- *
- * apache/spark
- */
+ /**
+ * Non-Aggregate function: Reverses the string column and returns it as a new string column.
+ *
+ * apache/spark
+ */
def reverse[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] =
str.typed(sparkFunctions.reverse(str.untyped))
- /** Non-Aggregate function: Right-pad the string column with pad to a length of len.
- * If the string column is longer than len, the return value is shortened to len characters.
- *
- * apache/spark
- */
- def rpad[T](str: AbstractTypedColumn[T, String], len: Int, pad: String): str.ThisType[T, String] =
+ /**
+ * Non-Aggregate function: Right-pad the string column with pad to a length of len.
+ * If the string column is longer than len, the return value is shortened to len characters.
+ *
+ * apache/spark
+ */
+ def rpad[T](
+ str: AbstractTypedColumn[T, String],
+ len: Int,
+ pad: String
+ ): str.ThisType[T, String] =
str.typed(sparkFunctions.rpad(str.untyped, len, pad))
- /** Non-Aggregate function: Trim the spaces from right end for the specified string value.
- *
- * apache/spark
- */
+ /**
+ * Non-Aggregate function: Trim the spaces from right end for the specified string value.
+ *
+ * apache/spark
+ */
def rtrim[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] =
str.typed(sparkFunctions.rtrim(str.untyped))
- /** Non-Aggregate function: Substring starts at `pos` and is of length `len`
- *
- * apache/spark
- */
- //TODO: Also for byte array
- def substring[T](str: AbstractTypedColumn[T, String], pos: Int, len: Int): str.ThisType[T, String] =
+ /**
+ * Non-Aggregate function: Substring starts at `pos` and is of length `len`
+ *
+ * apache/spark
+ */
+ // TODO: Also for byte array
+ def substring[T](
+ str: AbstractTypedColumn[T, String],
+ pos: Int,
+ len: Int
+ ): str.ThisType[T, String] =
str.typed(sparkFunctions.substring(str.untyped, pos, len))
- /** Non-Aggregate function: Trim the spaces from both ends for the specified string column.
- *
- * apache/spark
- */
+ /**
+ * Non-Aggregate function: Trim the spaces from both ends for the specified string column.
+ *
+ * apache/spark
+ */
def trim[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] =
str.typed(sparkFunctions.trim(str.untyped))
- /** Non-Aggregate function: Converts a string column to upper case.
- *
- * apache/spark
- */
+ /**
+ * Non-Aggregate function: Converts a string column to upper case.
+ *
+ * apache/spark
+ */
def upper[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] =
str.typed(sparkFunctions.upper(str.untyped))
@@ -715,93 +965,123 @@ trait NonAggregateFunctions {
// DateTime functions
//////////////////////////////////////////////////////////////////////////////////////////////
- /** Non-Aggregate function: Extracts the year as an integer from a given date/timestamp/string.
- *
- * Differs from `Column#year` by wrapping it's result into an `Option`.
- *
- * apache/spark
- */
- def year[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] =
+ /**
+ * Non-Aggregate function: Extracts the year as an integer from a given date/timestamp/string.
+ *
+ * Differs from `Column#year` by wrapping its result into an `Option`.
+ *
+ * apache/spark
+ */
+ def year[T](
+ str: AbstractTypedColumn[T, String]
+ ): str.ThisType[T, Option[Int]] =
str.typed(sparkFunctions.year(str.untyped))
- /** Non-Aggregate function: Extracts the quarter as an integer from a given date/timestamp/string.
- *
- * Differs from `Column#quarter` by wrapping it's result into an `Option`.
- *
- * apache/spark
- */
- def quarter[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] =
+ /**
+ * Non-Aggregate function: Extracts the quarter as an integer from a given date/timestamp/string.
+ *
+ * Differs from `Column#quarter` by wrapping its result into an `Option`.
+ *
+ * apache/spark
+ */
+ def quarter[T](
+ str: AbstractTypedColumn[T, String]
+ ): str.ThisType[T, Option[Int]] =
str.typed(sparkFunctions.quarter(str.untyped))
- /** Non-Aggregate function Extracts the month as an integer from a given date/timestamp/string.
- *
- * Differs from `Column#month` by wrapping it's result into an `Option`.
- *
- * apache/spark
- */
- def month[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] =
+ /**
+ * Non-Aggregate function: Extracts the month as an integer from a given date/timestamp/string.
+ *
+ * Differs from `Column#month` by wrapping its result into an `Option`.
+ *
+ * apache/spark
+ */
+ def month[T](
+ str: AbstractTypedColumn[T, String]
+ ): str.ThisType[T, Option[Int]] =
str.typed(sparkFunctions.month(str.untyped))
- /** Non-Aggregate function: Extracts the day of the week as an integer from a given date/timestamp/string.
- *
- * Differs from `Column#dayofweek` by wrapping it's result into an `Option`.
- *
- * apache/spark
- */
- def dayofweek[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] =
+ /**
+ * Non-Aggregate function: Extracts the day of the week as an integer from a given date/timestamp/string.
+ *
+ * Differs from `Column#dayofweek` by wrapping its result into an `Option`.
+ *
+ * apache/spark
+ */
+ def dayofweek[T](
+ str: AbstractTypedColumn[T, String]
+ ): str.ThisType[T, Option[Int]] =
str.typed(sparkFunctions.dayofweek(str.untyped))
- /** Non-Aggregate function: Extracts the day of the month as an integer from a given date/timestamp/string.
- *
- * Differs from `Column#dayofmonth` by wrapping it's result into an `Option`.
- *
- * apache/spark
- */
- def dayofmonth[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] =
+ /**
+ * Non-Aggregate function: Extracts the day of the month as an integer from a given date/timestamp/string.
+ *
+ * Differs from `Column#dayofmonth` by wrapping its result into an `Option`.
+ *
+ * apache/spark
+ */
+ def dayofmonth[T](
+ str: AbstractTypedColumn[T, String]
+ ): str.ThisType[T, Option[Int]] =
str.typed(sparkFunctions.dayofmonth(str.untyped))
- /** Non-Aggregate function: Extracts the day of the year as an integer from a given date/timestamp/string.
- *
- * Differs from `Column#dayofyear` by wrapping it's result into an `Option`.
- *
- * apache/spark
- */
- def dayofyear[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] =
+ /**
+ * Non-Aggregate function: Extracts the day of the year as an integer from a given date/timestamp/string.
+ *
+ * Differs from `Column#dayofyear` by wrapping its result into an `Option`.
+ *
+ * apache/spark
+ */
+ def dayofyear[T](
+ str: AbstractTypedColumn[T, String]
+ ): str.ThisType[T, Option[Int]] =
str.typed(sparkFunctions.dayofyear(str.untyped))
- /** Non-Aggregate function: Extracts the hours as an integer from a given date/timestamp/string.
- *
- * Differs from `Column#hour` by wrapping it's result into an `Option`.
- *
- * apache/spark
- */
- def hour[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] =
+ /**
+ * Non-Aggregate function: Extracts the hours as an integer from a given date/timestamp/string.
+ *
+ * Differs from `Column#hour` by wrapping its result into an `Option`.
+ *
+ * apache/spark
+ */
+ def hour[T](
+ str: AbstractTypedColumn[T, String]
+ ): str.ThisType[T, Option[Int]] =
str.typed(sparkFunctions.hour(str.untyped))
- /** Non-Aggregate function: Extracts the minutes as an integer from a given date/timestamp/string.
- *
- * Differs from `Column#minute` by wrapping it's result into an `Option`.
- *
- * apache/spark
- */
- def minute[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] =
+ /**
+ * Non-Aggregate function: Extracts the minutes as an integer from a given date/timestamp/string.
+ *
+ * Differs from `Column#minute` by wrapping its result into an `Option`.
+ *
+ * apache/spark
+ */
+ def minute[T](
+ str: AbstractTypedColumn[T, String]
+ ): str.ThisType[T, Option[Int]] =
str.typed(sparkFunctions.minute(str.untyped))
- /** Non-Aggregate function: Extracts the seconds as an integer from a given date/timestamp/string.
- *
- * Differs from `Column#second` by wrapping it's result into an `Option`.
- *
- * apache/spark
- */
- def second[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] =
+ /**
+ * Non-Aggregate function: Extracts the seconds as an integer from a given date/timestamp/string.
+ *
+ * Differs from `Column#second` by wrapping its result into an `Option`.
+ *
+ * apache/spark
+ */
+ def second[T](
+ str: AbstractTypedColumn[T, String]
+ ): str.ThisType[T, Option[Int]] =
str.typed(sparkFunctions.second(str.untyped))
- /** Non-Aggregate function: Extracts the week number as an integer from a given date/timestamp/string.
- *
- * Differs from `Column#weekofyear` by wrapping it's result into an `Option`.
- *
- * apache/spark
- */
- def weekofyear[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] =
+ /**
+ * Non-Aggregate function: Extracts the week number as an integer from a given date/timestamp/string.
+ *
+ * Differs from `Column#weekofyear` by wrapping its result into an `Option`.
+ *
+ * apache/spark
+ */
+ def weekofyear[T](
+ str: AbstractTypedColumn[T, String]
+ ): str.ThisType[T, Option[Int]] =
str.typed(sparkFunctions.weekofyear(str.untyped))
}
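For orientation, a minimal usage sketch of a few of the non-aggregate functions reformatted above. This is illustrative only: `Person` and `ds` are assumed names (not part of this change), and the functions are assumed to be in scope via the usual `frameless.functions.nonAggregate._` import.

  import frameless.TypedDataset
  import frameless.functions.nonAggregate._

  case class Person(name: String, active: Boolean)

  def summarize(ds: TypedDataset[Person]): TypedDataset[(String, Int, String)] =
    ds.select(
      upper(ds('name)),              // string function: upper-cases the name column
      length(ds('name)),             // string function: length in characters
      when(ds('active), ds('name))   // conditional: keep the name when active,
        .otherwise(lower(ds('name))) // otherwise fall back to the lower-cased name
    )

The select produces a typed tuple dataset, with each element's type determined by the corresponding typed column expression.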
diff --git a/dataset/src/main/scala/frameless/functions/Udf.scala b/dataset/src/main/scala/frameless/functions/Udf.scala
index 93ba7f118..42d65e8ca 100644
--- a/dataset/src/main/scala/frameless/functions/Udf.scala
+++ b/dataset/src/main/scala/frameless/functions/Udf.scala
@@ -2,90 +2,118 @@ package frameless
package functions
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, NonSQLExpression}
+import org.apache.spark.sql.catalyst.expressions.{
+ Expression,
+ LeafExpression,
+ NonSQLExpression
+}
import org.apache.spark.sql.catalyst.expressions.codegen._
import Block._
import org.apache.spark.sql.types.DataType
import shapeless.syntax.std.tuple._
-/** Documentation marked "apache/spark" is thanks to apache/spark Contributors
- * at https://github.com/apache/spark, licensed under Apache v2.0 available at
- * http://www.apache.org/licenses/LICENSE-2.0
- */
+/**
+ * Documentation marked "apache/spark" is thanks to apache/spark Contributors
+ * at https://github.com/apache/spark, licensed under Apache v2.0 available at
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
trait Udf {
- /** Defines a user-defined function of 1 arguments as user-defined function (UDF).
- * The data types are automatically inferred based on the function's signature.
- *
- * apache/spark
- */
- def udf[T, A, R: TypedEncoder](f: A => R):
- TypedColumn[T, A] => TypedColumn[T, R] = {
- u =>
- val scalaUdf = FramelessUdf(f, List(u), TypedEncoder[R])
- new TypedColumn[T, R](scalaUdf)
+ /**
+ * Defines a user-defined function of 1 argument as a user-defined function (UDF).
+ * The data types are automatically inferred based on the function's signature.
+ *
+ * apache/spark
+ */
+ def udf[T, A, R: TypedEncoder](
+ f: A => R
+ ): TypedColumn[T, A] => TypedColumn[T, R] = { u =>
+ val scalaUdf = FramelessUdf(f, List(u), TypedEncoder[R])
+ new TypedColumn[T, R](scalaUdf)
}
- /** Defines a user-defined function of 2 arguments as user-defined function (UDF).
- * The data types are automatically inferred based on the function's signature.
- *
- * apache/spark
- */
- def udf[T, A1, A2, R: TypedEncoder](f: (A1,A2) => R):
- (TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R] = {
+ /**
+ * Defines a user-defined function of 2 arguments as a user-defined function (UDF).
+ * The data types are automatically inferred based on the function's signature.
+ *
+ * apache/spark
+ */
+ def udf[T, A1, A2, R: TypedEncoder](
+ f: (A1, A2) => R
+ ): (TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R] = {
case us =>
- val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R])
+ val scalaUdf =
+ FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R])
new TypedColumn[T, R](scalaUdf)
- }
+ }
- /** Defines a user-defined function of 3 arguments as user-defined function (UDF).
- * The data types are automatically inferred based on the function's signature.
- *
- * apache/spark
- */
- def udf[T, A1, A2, A3, R: TypedEncoder](f: (A1,A2,A3) => R):
- (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R] = {
+ /**
+ * Defines a user-defined function of 3 arguments as a user-defined function (UDF).
+ * The data types are automatically inferred based on the function's signature.
+ *
+ * apache/spark
+ */
+ def udf[T, A1, A2, A3, R: TypedEncoder](
+ f: (A1, A2, A3) => R
+ ): (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R] = {
case us =>
- val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R])
+ val scalaUdf =
+ FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R])
new TypedColumn[T, R](scalaUdf)
- }
+ }
- /** Defines a user-defined function of 4 arguments as user-defined function (UDF).
- * The data types are automatically inferred based on the function's signature.
- *
- * apache/spark
- */
- def udf[T, A1, A2, A3, A4, R: TypedEncoder](f: (A1,A2,A3,A4) => R):
- (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R] = {
+ /**
+ * Defines a user-defined function of 4 arguments as a user-defined function (UDF).
+ * The data types are automatically inferred based on the function's signature.
+ *
+ * apache/spark
+ */
+ def udf[T, A1, A2, A3, A4, R: TypedEncoder](f: (A1, A2, A3, A4) => R): (
+ TypedColumn[T, A1],
+ TypedColumn[T, A2],
+ TypedColumn[T, A3],
+ TypedColumn[T, A4]
+ ) => TypedColumn[T, R] = {
case us =>
- val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R])
+ val scalaUdf =
+ FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R])
new TypedColumn[T, R](scalaUdf)
- }
+ }
- /** Defines a user-defined function of 5 arguments as user-defined function (UDF).
- * The data types are automatically inferred based on the function's signature.
- *
- * apache/spark
- */
- def udf[T, A1, A2, A3, A4, A5, R: TypedEncoder](f: (A1,A2,A3,A4,A5) => R):
- (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R] = {
+ /**
+ * Defines a user-defined function of 5 arguments as a user-defined function (UDF).
+ * The data types are automatically inferred based on the function's signature.
+ *
+ * apache/spark
+ */
+ def udf[T, A1, A2, A3, A4, A5, R: TypedEncoder](
+ f: (A1, A2, A3, A4, A5) => R
+ ): (
+ TypedColumn[T, A1],
+ TypedColumn[T, A2],
+ TypedColumn[T, A3],
+ TypedColumn[T, A4],
+ TypedColumn[T, A5]
+ ) => TypedColumn[T, R] = {
case us =>
- val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R])
+ val scalaUdf =
+ FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R])
new TypedColumn[T, R](scalaUdf)
- }
+ }
}
/**
- * NB: Implementation detail, isn't intended to be directly used.
- *
- * Our own implementation of `ScalaUDF` from Catalyst compatible with [[TypedEncoder]].
- */
+ * NB: Implementation detail; not intended to be used directly.
+ *
+ * Our own implementation of `ScalaUDF` from Catalyst compatible with [[TypedEncoder]].
+ */
case class FramelessUdf[T, R](
- function: AnyRef,
- encoders: Seq[TypedEncoder[_]],
- children: Seq[Expression],
- rencoder: TypedEncoder[R]
-) extends Expression with NonSQLExpression {
+ function: AnyRef,
+ encoders: Seq[TypedEncoder[_]],
+ children: Seq[Expression],
+ rencoder: TypedEncoder[R])
+ extends Expression
+ with NonSQLExpression {
override def nullable: Boolean = rencoder.nullable
override def toString: String = s"FramelessUdf(${children.mkString(", ")})"
@@ -118,10 +146,12 @@ case class FramelessUdf[T, R](
"""
val code = CodeFormatter.stripOverlappingComments(
- new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
+ new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())
+ )
val (clazz, _) = CodeGenerator.compile(code)
- val codegen = clazz.generate(ctx.references.toArray).asInstanceOf[InternalRow => AnyRef]
+ val codegen =
+ clazz.generate(ctx.references.toArray).asInstanceOf[InternalRow => AnyRef]
codegen
}
@@ -139,29 +169,45 @@ case class FramelessUdf[T, R](
val framelessUdfClassName = classOf[FramelessUdf[_, _]].getName
val funcClassName = s"scala.Function${children.size}"
val funcExpressionIdx = ctx.references.size - 1
- val funcTerm = ctx.addMutableState(funcClassName, ctx.freshName("udf"),
- v => s"$v = ($funcClassName)((($framelessUdfClassName)references" +
- s"[$funcExpressionIdx]).function());")
-
- val (argsCode, funcArguments) = encoders.zip(children).map {
- case (encoder, child) =>
- val eval = child.genCode(ctx)
- val codeTpe = CodeGenerator.boxedType(encoder.jvmRepr)
- val argTerm = ctx.freshName("arg")
- val convert = s"${eval.code}\n$codeTpe $argTerm = ${eval.isNull} ? (($codeTpe)null) : (($codeTpe)(${eval.value}));"
+ val funcTerm = ctx.addMutableState(
+ funcClassName,
+ ctx.freshName("udf"),
+ v =>
+ s"$v = ($funcClassName)((($framelessUdfClassName)references" +
+ s"[$funcExpressionIdx]).function());"
+ )
- (convert, argTerm)
- }.unzip
+ val (argsCode, funcArguments) = encoders
+ .zip(children)
+ .map {
+ case (encoder, child) =>
+ val eval = child.genCode(ctx)
+ val codeTpe = CodeGenerator.boxedType(encoder.jvmRepr)
+ val argTerm = ctx.freshName("arg")
+ val convert =
+ s"${eval.code}\n$codeTpe $argTerm = ${eval.isNull} ? (($codeTpe)null) : (($codeTpe)(${eval.value}));"
+
+ (convert, argTerm)
+ }
+ .unzip
val internalTpe = CodeGenerator.boxedType(rencoder.jvmRepr)
- val internalTerm = ctx.addMutableState(internalTpe, ctx.freshName("internal"))
- val internalNullTerm = ctx.addMutableState("boolean", ctx.freshName("internalNull"))
+ val internalTerm =
+ ctx.addMutableState(internalTpe, ctx.freshName("internal"))
+ val internalNullTerm =
+ ctx.addMutableState("boolean", ctx.freshName("internalNull"))
// CTw - can't inject the term, may have to duplicate old code for parity
- val internalExpr = Spark2_4_LambdaVariable(internalTerm, internalNullTerm, rencoder.jvmRepr, true)
+ val internalExpr = Spark2_4_LambdaVariable(
+ internalTerm,
+ internalNullTerm,
+ rencoder.jvmRepr,
+ true
+ )
val resultEval = rencoder.toCatalyst(internalExpr).genCode(ctx)
- ev.copy(code = code"""
+ ev.copy(
+ code = code"""
${argsCode.mkString("\n")}
$internalTerm =
@@ -175,21 +221,28 @@ case class FramelessUdf[T, R](
)
}
- protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren)
+ protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]
+ ): Expression = copy(children = newChildren)
}
case class Spark2_4_LambdaVariable(
- value: String,
- isNull: String,
- dataType: DataType,
- nullable: Boolean = true) extends LeafExpression with NonSQLExpression {
+ value: String,
+ isNull: String,
+ dataType: DataType,
+ nullable: Boolean = true)
+ extends LeafExpression
+ with NonSQLExpression {
- private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType)
+ private val accessor: (InternalRow, Int) => Any =
+ InternalRow.getAccessor(dataType)
// Interpreted execution of `LambdaVariable` always get the 0-index element from input row.
override def eval(input: InternalRow): Any = {
- assert(input.numFields == 1,
- "The input row of interpreted LambdaVariable should have only 1 field.")
+ assert(
+ input.numFields == 1,
+ "The input row of interpreted LambdaVariable should have only 1 field."
+ )
if (nullable && input.isNullAt(0)) {
null
} else {
@@ -197,7 +250,10 @@ case class Spark2_4_LambdaVariable(
}
}
- override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ override protected def doGenCode(
+ ctx: CodegenContext,
+ ev: ExprCode
+ ): ExprCode = {
val isNullValue = if (nullable) {
JavaCode.isNullVariable(isNull)
} else {
@@ -208,12 +264,13 @@ case class Spark2_4_LambdaVariable(
}
object FramelessUdf {
+
// Spark needs case class with `children` field to mutate it
def apply[T, R](
- function: AnyRef,
- cols: Seq[UntypedExpression[T]],
- rencoder: TypedEncoder[R]
- ): FramelessUdf[T, R] = FramelessUdf(
+ function: AnyRef,
+ cols: Seq[UntypedExpression[T]],
+ rencoder: TypedEncoder[R]
+ ): FramelessUdf[T, R] = FramelessUdf(
function = function,
encoders = cols.map(_.uencoder).toList,
children = cols.map(x => x.uencoder.fromCatalyst(x.expr)).toList,
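As a quick, hedged illustration of the `udf` builders reformatted above (assuming `Person` and `ds` as illustrative names, and that `udf` is exposed through the `frameless.functions` package object):

  import frameless.{ TypedColumn, TypedDataset }
  import frameless.functions._

  case class Person(name: String, age: Int)

  // udf[T, A, R] lifts an ordinary Scala function into a typed column transformer
  val shout: TypedColumn[Person, String] => TypedColumn[Person, String] =
    udf[Person, String, String](_.toUpperCase + "!")

  def shouted(ds: TypedDataset[Person]): TypedDataset[String] =
    ds.select(shout(ds('name)))

Under the hood the lifted function becomes a `FramelessUdf` expression, which is what the codegen changes in this file operate on.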
diff --git a/dataset/src/main/scala/frameless/functions/UnaryFunctions.scala b/dataset/src/main/scala/frameless/functions/UnaryFunctions.scala
index 64bdf0ed1..6a9f41cdd 100644
--- a/dataset/src/main/scala/frameless/functions/UnaryFunctions.scala
+++ b/dataset/src/main/scala/frameless/functions/UnaryFunctions.scala
@@ -1,50 +1,74 @@
package frameless
package functions
-import org.apache.spark.sql.{Column, functions => sparkFunctions}
+import org.apache.spark.sql.{ Column, functions => sparkFunctions }
import scala.math.Ordering
trait UnaryFunctions {
- /** Returns length of array
- *
- * apache/spark
- */
- def size[T, A, V[_] : CatalystSizableCollection](column: TypedColumn[T, V[A]]): TypedColumn[T, Int] =
- new TypedColumn[T, Int](implicitly[CatalystSizableCollection[V]].sizeOp(column.untyped))
-
- /** Returns length of Map
- *
- * apache/spark
- */
+
+ /**
+ * Returns the length of an array
+ *
+ * apache/spark
+ */
+ def size[T, A, V[_]: CatalystSizableCollection](
+ column: TypedColumn[T, V[A]]
+ ): TypedColumn[T, Int] =
+ new TypedColumn[T, Int](
+ implicitly[CatalystSizableCollection[V]].sizeOp(column.untyped)
+ )
+
+ /**
+ * Returns the length of a Map
+ *
+ * apache/spark
+ */
def size[T, A, B](column: TypedColumn[T, Map[A, B]]): TypedColumn[T, Int] =
new TypedColumn[T, Int](sparkFunctions.size(column.untyped))
- /** Sorts the input array for the given column in ascending order, according to
- * the natural ordering of the array elements.
- *
- * apache/spark
- */
- def sortAscending[T, A: Ordering, V[_] : CatalystSortableCollection](column: TypedColumn[T, V[A]]): TypedColumn[T, V[A]] =
- new TypedColumn[T, V[A]](implicitly[CatalystSortableCollection[V]].sortOp(column.untyped, sortAscending = true))(column.uencoder)
-
- /** Sorts the input array for the given column in descending order, according to
- * the natural ordering of the array elements.
- *
- * apache/spark
- */
- def sortDescending[T, A: Ordering, V[_] : CatalystSortableCollection](column: TypedColumn[T, V[A]]): TypedColumn[T, V[A]] =
- new TypedColumn[T, V[A]](implicitly[CatalystSortableCollection[V]].sortOp(column.untyped, sortAscending = false))(column.uencoder)
-
-
- /** Creates a new row for each element in the given collection. The column types
- * eligible for this operation are constrained by CatalystExplodableCollection.
- *
- * apache/spark
- */
- @deprecated("Use explode() from the TypedDataset instead. This method will result in " +
- "runtime error if applied to two columns in the same select statement.", "0.6.2")
- def explode[T, A: TypedEncoder, V[_] : CatalystExplodableCollection](column: TypedColumn[T, V[A]]): TypedColumn[T, A] =
+ /**
+ * Sorts the input array for the given column in ascending order, according to
+ * the natural ordering of the array elements.
+ *
+ * apache/spark
+ */
+ def sortAscending[T, A: Ordering, V[_]: CatalystSortableCollection](
+ column: TypedColumn[T, V[A]]
+ ): TypedColumn[T, V[A]] =
+ new TypedColumn[T, V[A]](
+ implicitly[CatalystSortableCollection[V]]
+ .sortOp(column.untyped, sortAscending = true)
+ )(column.uencoder)
+
+ /**
+ * Sorts the input array for the given column in descending order, according to
+ * the natural ordering of the array elements.
+ *
+ * apache/spark
+ */
+ def sortDescending[T, A: Ordering, V[_]: CatalystSortableCollection](
+ column: TypedColumn[T, V[A]]
+ ): TypedColumn[T, V[A]] =
+ new TypedColumn[T, V[A]](
+ implicitly[CatalystSortableCollection[V]]
+ .sortOp(column.untyped, sortAscending = false)
+ )(column.uencoder)
+
+ /**
+ * Creates a new row for each element in the given collection. The column types
+ * eligible for this operation are constrained by CatalystExplodableCollection.
+ *
+ * apache/spark
+ */
+ @deprecated(
+ "Use explode() from the TypedDataset instead. This method will result in " +
+ "runtime error if applied to two columns in the same select statement.",
+ "0.6.2"
+ )
+ def explode[T, A: TypedEncoder, V[_]: CatalystExplodableCollection](
+ column: TypedColumn[T, V[A]]
+ ): TypedColumn[T, A] =
new TypedColumn[T, A](sparkFunctions.explode(column.untyped))
}
@@ -53,27 +77,39 @@ trait CatalystSizableCollection[V[_]] {
}
object CatalystSizableCollection {
- implicit def sizableVector: CatalystSizableCollection[Vector] = new CatalystSizableCollection[Vector] {
- def sizeOp(col: Column): Column = sparkFunctions.size(col)
- }
- implicit def sizableArray: CatalystSizableCollection[Array] = new CatalystSizableCollection[Array] {
- def sizeOp(col: Column): Column = sparkFunctions.size(col)
- }
+ implicit def sizableVector: CatalystSizableCollection[Vector] =
+ new CatalystSizableCollection[Vector] {
+ def sizeOp(col: Column): Column = sparkFunctions.size(col)
+ }
+
+ implicit def sizableArray: CatalystSizableCollection[Array] =
+ new CatalystSizableCollection[Array] {
+ def sizeOp(col: Column): Column = sparkFunctions.size(col)
+ }
- implicit def sizableList: CatalystSizableCollection[List] = new CatalystSizableCollection[List] {
- def sizeOp(col: Column): Column = sparkFunctions.size(col)
- }
+ implicit def sizableList: CatalystSizableCollection[List] =
+ new CatalystSizableCollection[List] {
+ def sizeOp(col: Column): Column = sparkFunctions.size(col)
+ }
}
trait CatalystExplodableCollection[V[_]]
object CatalystExplodableCollection {
- implicit def explodableVector: CatalystExplodableCollection[Vector] = new CatalystExplodableCollection[Vector] {}
- implicit def explodableArray: CatalystExplodableCollection[Array] = new CatalystExplodableCollection[Array] {}
- implicit def explodableList: CatalystExplodableCollection[List] = new CatalystExplodableCollection[List] {}
- implicit def explodableSeq: CatalystExplodableCollection[Seq] = new CatalystExplodableCollection[Seq] {}
+
+ implicit def explodableVector: CatalystExplodableCollection[Vector] =
+ new CatalystExplodableCollection[Vector] {}
+
+ implicit def explodableArray: CatalystExplodableCollection[Array] =
+ new CatalystExplodableCollection[Array] {}
+
+ implicit def explodableList: CatalystExplodableCollection[List] =
+ new CatalystExplodableCollection[List] {}
+
+ implicit def explodableSeq: CatalystExplodableCollection[Seq] =
+ new CatalystExplodableCollection[Seq] {}
}
trait CatalystSortableCollection[V[_]] {
@@ -81,15 +117,22 @@ trait CatalystSortableCollection[V[_]] {
}
object CatalystSortableCollection {
- implicit def sortableVector: CatalystSortableCollection[Vector] = new CatalystSortableCollection[Vector] {
- def sortOp(col: Column, sortAscending: Boolean): Column = sparkFunctions.sort_array(col, sortAscending)
- }
-
- implicit def sortableArray: CatalystSortableCollection[Array] = new CatalystSortableCollection[Array] {
- def sortOp(col: Column, sortAscending: Boolean): Column = sparkFunctions.sort_array(col, sortAscending)
- }
- implicit def sortableList: CatalystSortableCollection[List] = new CatalystSortableCollection[List] {
- def sortOp(col: Column, sortAscending: Boolean): Column = sparkFunctions.sort_array(col, sortAscending)
- }
+ implicit def sortableVector: CatalystSortableCollection[Vector] =
+ new CatalystSortableCollection[Vector] {
+ def sortOp(col: Column, sortAscending: Boolean): Column =
+ sparkFunctions.sort_array(col, sortAscending)
+ }
+
+ implicit def sortableArray: CatalystSortableCollection[Array] =
+ new CatalystSortableCollection[Array] {
+ def sortOp(col: Column, sortAscending: Boolean): Column =
+ sparkFunctions.sort_array(col, sortAscending)
+ }
+
+ implicit def sortableList: CatalystSortableCollection[List] =
+ new CatalystSortableCollection[List] {
+ def sortOp(col: Column, sortAscending: Boolean): Column =
+ sparkFunctions.sort_array(col, sortAscending)
+ }
}
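A small, hedged sketch of the collection helpers reformatted above (names are illustrative; `size` and `sortAscending` are assumed to be available from the `frameless.functions` package object):

  import frameless.TypedDataset
  import frameless.functions._

  case class Tags(values: Vector[Int])

  def tagStats(ds: TypedDataset[Tags]): TypedDataset[(Int, Vector[Int])] =
    ds.select(
      size(ds('values)),          // number of elements, via CatalystSizableCollection[Vector]
      sortAscending(ds('values))  // natural-order sort, via CatalystSortableCollection[Vector]
    )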
diff --git a/dataset/src/main/scala/frameless/ops/AggregateTypes.scala b/dataset/src/main/scala/frameless/ops/AggregateTypes.scala
index 403c25301..58d8cb27f 100644
--- a/dataset/src/main/scala/frameless/ops/AggregateTypes.scala
+++ b/dataset/src/main/scala/frameless/ops/AggregateTypes.scala
@@ -3,26 +3,32 @@ package ops
import shapeless._
-/** A type class to extract the column types out of an HList of [[frameless.TypedAggregate]].
- *
- * @note This type class is mostly a workaround to issue with slow implicit derivation for Comapped.
- * @example
- * {{{
- * type U = TypedAggregate[T,A] :: TypedAggregate[T,B] :: TypedAggregate[T,C] :: HNil
- * type Out = A :: B :: C :: HNil
- * }}}
- */
+/**
+ * A type class to extract the column types out of an HList of [[frameless.TypedAggregate]].
+ *
+ * @note This type class is mostly a workaround for an issue with slow implicit derivation for Comapped.
+ * @example
+ * {{{
+ * type U = TypedAggregate[T,A] :: TypedAggregate[T,B] :: TypedAggregate[T,C] :: HNil
+ * type Out = A :: B :: C :: HNil
+ * }}}
+ */
trait AggregateTypes[V, U <: HList] {
type Out <: HList
}
object AggregateTypes {
- type Aux[V, U <: HList, Out0 <: HList] = AggregateTypes[V, U] {type Out = Out0}
- implicit def deriveHNil[T]: AggregateTypes.Aux[T, HNil, HNil] = new AggregateTypes[T, HNil] { type Out = HNil }
+ type Aux[V, U <: HList, Out0 <: HList] = AggregateTypes[V, U] {
+ type Out = Out0
+ }
+
+ implicit def deriveHNil[T]: AggregateTypes.Aux[T, HNil, HNil] =
+ new AggregateTypes[T, HNil] { type Out = HNil }
implicit def deriveCons1[T, H, TT <: HList, V <: HList](
- implicit tail: AggregateTypes.Aux[T, TT, V]
- ): AggregateTypes.Aux[T, TypedAggregate[T, H] :: TT, H :: V] =
- new AggregateTypes[T, TypedAggregate[T, H] :: TT] {type Out = H :: V}
+ implicit
+ tail: AggregateTypes.Aux[T, TT, V]
+ ): AggregateTypes.Aux[T, TypedAggregate[T, H] :: TT, H :: V] =
+ new AggregateTypes[T, TypedAggregate[T, H] :: TT] { type Out = H :: V }
}
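To make the type-level mapping concrete, a compile-time-only sketch (illustrative; `Foo` is an assumed type):

  import frameless.TypedAggregate
  import frameless.ops.AggregateTypes
  import shapeless.{ ::, HNil }

  trait Foo

  // Resolved by deriveCons1 (applied twice) and deriveHNil: the aggregate value
  // types are extracted out of the HList of TypedAggregate columns.
  type In = TypedAggregate[Foo, Int] :: TypedAggregate[Foo, String] :: HNil
  implicitly[AggregateTypes.Aux[Foo, In, Int :: String :: HNil]]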
diff --git a/dataset/src/main/scala/frameless/ops/As.scala b/dataset/src/main/scala/frameless/ops/As.scala
index 06b691028..e9c553f72 100644
--- a/dataset/src/main/scala/frameless/ops/As.scala
+++ b/dataset/src/main/scala/frameless/ops/As.scala
@@ -1,10 +1,12 @@
package frameless
package ops
-import shapeless.{::, Generic, HList, Lazy}
+import shapeless.{ ::, Generic, HList, Lazy }
/** Evidence for correctness of `TypedDataset[T].as[U]` */
-class As[T, U] private (implicit val encoder: TypedEncoder[U])
+class As[T, U] private (
+ implicit
+ val encoder: TypedEncoder[U])
object As extends LowPriorityAs {
@@ -12,8 +14,8 @@ object As extends LowPriorityAs {
implicit def equivIdentity[A] = new Equiv[A, A]
- implicit def deriveAs[A, B]
- (implicit
+ implicit def deriveAs[A, B](
+ implicit
i0: TypedEncoder[B],
i1: Equiv[A, B]
): As[A, B] = new As[A, B]
@@ -24,14 +26,14 @@ trait LowPriorityAs {
import As.Equiv
- implicit def equivHList[AH, AT <: HList, BH, BT <: HList]
- (implicit
+ implicit def equivHList[AH, AT <: HList, BH, BT <: HList](
+ implicit
i0: Lazy[Equiv[AH, BH]],
i1: Equiv[AT, BT]
): Equiv[AH :: AT, BH :: BT] = new Equiv[AH :: AT, BH :: BT]
- implicit def equivGeneric[A, B, R, S]
- (implicit
+ implicit def equivGeneric[A, B, R, S](
+ implicit
i0: Generic.Aux[A, R],
i1: Generic.Aux[B, S],
i2: Lazy[Equiv[R, S]]
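A hedged example of what the `As` evidence enables (the scaladoc above ties it to `TypedDataset[T].as[U]`; the case classes here are assumed names):

  import frameless.ops.As

  case class Person(name: String, age: Int)
  case class PersonView(name: String, age: Int) // same field types, different class

  // Compiles because Equiv is derived field-by-field (equivGeneric + equivHList +
  // equivIdentity) and a TypedEncoder[PersonView] is available; this is the
  // evidence consulted by TypedDataset[Person].as[PersonView].
  implicitly[As[Person, PersonView]]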
diff --git a/dataset/src/main/scala/frameless/ops/ColumnTypes.scala b/dataset/src/main/scala/frameless/ops/ColumnTypes.scala
index e5ae6aea2..923b0a168 100644
--- a/dataset/src/main/scala/frameless/ops/ColumnTypes.scala
+++ b/dataset/src/main/scala/frameless/ops/ColumnTypes.scala
@@ -3,26 +3,29 @@ package ops
import shapeless._
-/** A type class to extract the column types out of an HList of [[frameless.TypedColumn]].
- *
- * @note This type class is mostly a workaround to issue with slow implicit derivation for Comapped.
- * @example
- * {{{
- * type U = TypedColumn[T,A] :: TypedColumn[T,B] :: TypedColumn[T,C] :: HNil
- * type Out = A :: B :: C :: HNil
- * }}}
- */
+/**
+ * A type class to extract the column types out of an HList of [[frameless.TypedColumn]].
+ *
+ * @note This type class is mostly a workaround for an issue with slow implicit derivation for Comapped.
+ * @example
+ * {{{
+ * type U = TypedColumn[T,A] :: TypedColumn[T,B] :: TypedColumn[T,C] :: HNil
+ * type Out = A :: B :: C :: HNil
+ * }}}
+ */
trait ColumnTypes[T, U <: HList] {
type Out <: HList
}
object ColumnTypes {
- type Aux[T, U <: HList, Out0 <: HList] = ColumnTypes[T, U] {type Out = Out0}
+ type Aux[T, U <: HList, Out0 <: HList] = ColumnTypes[T, U] { type Out = Out0 }
- implicit def deriveHNil[T]: ColumnTypes.Aux[T, HNil, HNil] = new ColumnTypes[T, HNil] { type Out = HNil }
+ implicit def deriveHNil[T]: ColumnTypes.Aux[T, HNil, HNil] =
+ new ColumnTypes[T, HNil] { type Out = HNil }
implicit def deriveCons[T, H, TT <: HList, V <: HList](
- implicit tail: ColumnTypes.Aux[T, TT, V]
- ): ColumnTypes.Aux[T, TypedColumn[T, H] :: TT, H :: V] =
- new ColumnTypes[T, TypedColumn[T, H] :: TT] {type Out = H :: V}
+ implicit
+ tail: ColumnTypes.Aux[T, TT, V]
+ ): ColumnTypes.Aux[T, TypedColumn[T, H] :: TT, H :: V] =
+ new ColumnTypes[T, TypedColumn[T, H] :: TT] { type Out = H :: V }
}
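The same pattern as `AggregateTypes`, sketched for columns (illustrative; `Foo` is an assumed type):

  import frameless.TypedColumn
  import frameless.ops.ColumnTypes
  import shapeless.{ ::, HNil }

  trait Foo

  // deriveCons peels one TypedColumn[Foo, H] per HList element; deriveHNil terminates.
  implicitly[ColumnTypes.Aux[
    Foo,
    TypedColumn[Foo, Long] :: TypedColumn[Foo, String] :: HNil,
    Long :: String :: HNil
  ]]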
diff --git a/dataset/src/main/scala/frameless/ops/GroupByOps.scala b/dataset/src/main/scala/frameless/ops/GroupByOps.scala
index 3feeaca59..e4425e8f3 100644
--- a/dataset/src/main/scala/frameless/ops/GroupByOps.scala
+++ b/dataset/src/main/scala/frameless/ops/GroupByOps.scala
@@ -3,36 +3,54 @@ package ops
import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
import org.apache.spark.sql.catalyst.plans.logical.Project
-import org.apache.spark.sql.{Column, Dataset, FramelessInternals, RelationalGroupedDataset}
+import org.apache.spark.sql.{
+ Column,
+ Dataset,
+ FramelessInternals,
+ RelationalGroupedDataset
+}
import shapeless._
-import shapeless.ops.hlist.{Length, Mapped, Prepend, ToList, ToTraversable, Tupler}
+import shapeless.ops.hlist.{
+ Length,
+ Mapped,
+ Prepend,
+ ToList,
+ ToTraversable,
+ Tupler
+}
-class GroupedByManyOps[T, TK <: HList, K <: HList, KT]
- (self: TypedDataset[T], groupedBy: TK)
- (implicit
+class GroupedByManyOps[T, TK <: HList, K <: HList, KT](
+ self: TypedDataset[T],
+ groupedBy: TK
+ )(implicit
i0: ColumnTypes.Aux[T, TK, K],
i1: ToTraversable.Aux[TK, List, UntypedExpression[T]],
- i3: Tupler.Aux[K, KT]
- ) extends AggregatingOps[T, TK, K, KT](self, groupedBy, (dataset, cols) => dataset.groupBy(cols: _*)) {
+ i3: Tupler.Aux[K, KT])
+ extends AggregatingOps[T, TK, K, KT](
+ self,
+ groupedBy,
+ (dataset, cols) => dataset.groupBy(cols: _*)
+ ) {
+
object agg extends ProductArgs {
- def applyProduct[TC <: HList, C <: HList, Out0 <: HList, Out1]
- (columns: TC)
- (implicit
+
+ def applyProduct[TC <: HList, C <: HList, Out0 <: HList, Out1](
+ columns: TC
+ )(implicit
i3: AggregateTypes.Aux[T, TC, C],
i4: Prepend.Aux[K, C, Out0],
i5: Tupler.Aux[Out0, Out1],
i6: TypedEncoder[Out1],
i7: ToTraversable.Aux[TC, List, UntypedExpression[T]]
): TypedDataset[Out1] = {
- aggregate[TC, Out1](columns)
- }
+ aggregate[TC, Out1](columns)
+ }
}
}
class GroupedBy1Ops[K1, V](
- self: TypedDataset[V],
- g1: TypedColumn[V, K1]
-) {
+ self: TypedDataset[V],
+ g1: TypedColumn[V, K1]) {
private def underlying = new GroupedByManyOps(self, g1 :: HNil)
private implicit def eg1 = g1.uencoder
@@ -41,49 +59,77 @@ class GroupedBy1Ops[K1, V](
underlying.agg(c1)
}
- def agg[U1, U2](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2]): TypedDataset[(K1, U1, U2)] = {
+ def agg[U1, U2](
+ c1: TypedAggregate[V, U1],
+ c2: TypedAggregate[V, U2]
+ ): TypedDataset[(K1, U1, U2)] = {
implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder
underlying.agg(c1, c2)
}
- def agg[U1, U2, U3](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3]): TypedDataset[(K1, U1, U2, U3)] = {
- implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder
+ def agg[U1, U2, U3](
+ c1: TypedAggregate[V, U1],
+ c2: TypedAggregate[V, U2],
+ c3: TypedAggregate[V, U3]
+ ): TypedDataset[(K1, U1, U2, U3)] = {
+ implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder;
+ implicit val e3 = c3.uencoder
underlying.agg(c1, c2, c3)
}
- def agg[U1, U2, U3, U4](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3], c4: TypedAggregate[V, U4]): TypedDataset[(K1, U1, U2, U3, U4)] = {
- implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder
+ def agg[U1, U2, U3, U4](
+ c1: TypedAggregate[V, U1],
+ c2: TypedAggregate[V, U2],
+ c3: TypedAggregate[V, U3],
+ c4: TypedAggregate[V, U4]
+ ): TypedDataset[(K1, U1, U2, U3, U4)] = {
+ implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder;
+ implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder
underlying.agg(c1, c2, c3, c4)
}
- def agg[U1, U2, U3, U4, U5](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3], c4: TypedAggregate[V, U4], c5: TypedAggregate[V, U5]): TypedDataset[(K1, U1, U2, U3, U4, U5)] = {
- implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder; implicit val e5 = c5.uencoder
+ def agg[U1, U2, U3, U4, U5](
+ c1: TypedAggregate[V, U1],
+ c2: TypedAggregate[V, U2],
+ c3: TypedAggregate[V, U3],
+ c4: TypedAggregate[V, U4],
+ c5: TypedAggregate[V, U5]
+ ): TypedDataset[(K1, U1, U2, U3, U4, U5)] = {
+ implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder;
+ implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder;
+ implicit val e5 = c5.uencoder
underlying.agg(c1, c2, c3, c4, c5)
}
- /** Methods on `TypedDataset[T]` that go through a full serialization and
- * deserialization of `T`, and execute outside of the Catalyst runtime.
- */
+ /**
+ * Methods on `TypedDataset[T]` that go through a full serialization and
+ * deserialization of `T`, and execute outside of the Catalyst runtime.
+ */
object deserialized {
- def mapGroups[U: TypedEncoder](f: (K1, Iterator[V]) => U): TypedDataset[U] = {
+
+ def mapGroups[U: TypedEncoder](
+ f: (K1, Iterator[V]) => U
+ ): TypedDataset[U] = {
underlying.deserialized.mapGroups(AggregatingOps.tuple1(f))
}
- def flatMapGroups[U: TypedEncoder](f: (K1, Iterator[V]) => TraversableOnce[U]): TypedDataset[U] = {
+ def flatMapGroups[U: TypedEncoder](
+ f: (K1, Iterator[V]) => TraversableOnce[U]
+ ): TypedDataset[U] = {
underlying.deserialized.flatMapGroups(AggregatingOps.tuple1(f))
}
}
- def pivot[P: CatalystPivotable](pivotColumn: TypedColumn[V, P]): PivotNotValues[V, TypedColumn[V,K1] :: HNil, P] =
+ def pivot[P: CatalystPivotable](
+ pivotColumn: TypedColumn[V, P]
+ ): PivotNotValues[V, TypedColumn[V, K1] :: HNil, P] =
PivotNotValues(self, g1 :: HNil, pivotColumn)
}
-
class GroupedBy2Ops[K1, K2, V](
- self: TypedDataset[V],
- g1: TypedColumn[V, K1],
- g2: TypedColumn[V, K2]
-) {
+ self: TypedDataset[V],
+ g1: TypedColumn[V, K1],
+ g2: TypedColumn[V, K2]) {
private def underlying = new GroupedByManyOps(self, g1 :: g2 :: HNil)
private implicit def eg1 = g1.uencoder
private implicit def eg2 = g2.uencoder
@@ -93,57 +139,88 @@ class GroupedBy2Ops[K1, K2, V](
underlying.agg(c1)
}
- def agg[U1, U2](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2]): TypedDataset[(K1, K2, U1, U2)] = {
+ def agg[U1, U2](
+ c1: TypedAggregate[V, U1],
+ c2: TypedAggregate[V, U2]
+ ): TypedDataset[(K1, K2, U1, U2)] = {
implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder
underlying.agg(c1, c2)
}
- def agg[U1, U2, U3](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3]): TypedDataset[(K1, K2, U1, U2, U3)] = {
- implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder
+ def agg[U1, U2, U3](
+ c1: TypedAggregate[V, U1],
+ c2: TypedAggregate[V, U2],
+ c3: TypedAggregate[V, U3]
+ ): TypedDataset[(K1, K2, U1, U2, U3)] = {
+ implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder;
+ implicit val e3 = c3.uencoder
underlying.agg(c1, c2, c3)
}
- def agg[U1, U2, U3, U4](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3], c4: TypedAggregate[V, U4]): TypedDataset[(K1, K2, U1, U2, U3, U4)] = {
- implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder
- underlying.agg(c1 , c2 , c3 , c4)
+ def agg[U1, U2, U3, U4](
+ c1: TypedAggregate[V, U1],
+ c2: TypedAggregate[V, U2],
+ c3: TypedAggregate[V, U3],
+ c4: TypedAggregate[V, U4]
+ ): TypedDataset[(K1, K2, U1, U2, U3, U4)] = {
+ implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder;
+ implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder
+ underlying.agg(c1, c2, c3, c4)
}
- def agg[U1, U2, U3, U4, U5](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3], c4: TypedAggregate[V, U4], c5: TypedAggregate[V, U5]): TypedDataset[(K1, K2, U1, U2, U3, U4, U5)] = {
- implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder; implicit val e5 = c5.uencoder
+ def agg[U1, U2, U3, U4, U5](
+ c1: TypedAggregate[V, U1],
+ c2: TypedAggregate[V, U2],
+ c3: TypedAggregate[V, U3],
+ c4: TypedAggregate[V, U4],
+ c5: TypedAggregate[V, U5]
+ ): TypedDataset[(K1, K2, U1, U2, U3, U4, U5)] = {
+ implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder;
+ implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder;
+ implicit val e5 = c5.uencoder
underlying.agg(c1, c2, c3, c4, c5)
}
-
- /** Methods on `TypedDataset[T]` that go through a full serialization and
- * deserialization of `T`, and execute outside of the Catalyst runtime.
- */
+ /**
+ * Methods on `TypedDataset[T]` that go through a full serialization and
+ * deserialization of `T`, and execute outside of the Catalyst runtime.
+ */
object deserialized {
- def mapGroups[U: TypedEncoder](f: ((K1, K2), Iterator[V]) => U): TypedDataset[U] = {
+
+ def mapGroups[U: TypedEncoder](
+ f: ((K1, K2), Iterator[V]) => U
+ ): TypedDataset[U] = {
underlying.deserialized.mapGroups(f)
}
- def flatMapGroups[U: TypedEncoder](f: ((K1, K2), Iterator[V]) => TraversableOnce[U]): TypedDataset[U] = {
+ def flatMapGroups[U: TypedEncoder](
+ f: ((K1, K2), Iterator[V]) => TraversableOnce[U]
+ ): TypedDataset[U] = {
underlying.deserialized.flatMapGroups(f)
}
}
- def pivot[P: CatalystPivotable](pivotColumn: TypedColumn[V, P]):
- PivotNotValues[V, TypedColumn[V,K1] :: TypedColumn[V, K2] :: HNil, P] =
- PivotNotValues(self, g1 :: g2 :: HNil, pivotColumn)
+ def pivot[P: CatalystPivotable](
+ pivotColumn: TypedColumn[V, P]
+ ): PivotNotValues[V, TypedColumn[V, K1] :: TypedColumn[V, K2] :: HNil, P] =
+ PivotNotValues(self, g1 :: g2 :: HNil, pivotColumn)
}
-private[ops] abstract class AggregatingOps[T, TK <: HList, K <: HList, KT]
- (self: TypedDataset[T], groupedBy: TK, groupingFunc: (Dataset[T], Seq[Column]) => RelationalGroupedDataset)
- (implicit
+private[ops] abstract class AggregatingOps[T, TK <: HList, K <: HList, KT](
+ self: TypedDataset[T],
+ groupedBy: TK,
+ groupingFunc: (Dataset[T], Seq[Column]) => RelationalGroupedDataset
+ )(implicit
i0: ColumnTypes.Aux[T, TK, K],
i1: ToTraversable.Aux[TK, List, UntypedExpression[T]],
- i2: Tupler.Aux[K, KT]
- ) {
- def aggregate[TC <: HList, Out1](columns: TC)
- (implicit
- i7: TypedEncoder[Out1],
- i8: ToTraversable.Aux[TC, List, UntypedExpression[T]]
- ): TypedDataset[Out1] = {
+ i2: Tupler.Aux[K, KT]) {
+
+ def aggregate[TC <: HList, Out1](
+ columns: TC
+ )(implicit
+ i7: TypedEncoder[Out1],
+ i8: ToTraversable.Aux[TC, List, UntypedExpression[T]]
+ ): TypedDataset[Out1] = {
def expr(c: UntypedExpression[T]): Column = new Column(c.expr)
val groupByExprs = groupedBy.toList[UntypedExpression[T]].map(expr)
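For context on the grouped-aggregation API whose signatures are being wrapped above, a minimal hedged sketch (illustrative names; `count` is assumed to come from `frameless.functions.aggregate._`):

  import frameless.TypedDataset
  import frameless.functions.aggregate._

  case class Person(name: String, active: Boolean)

  // groupBy on one column goes through GroupedBy1Ops; the aggregate's value type
  // is appended after the key type in the resulting tuple.
  def perName(ds: TypedDataset[Person]): TypedDataset[(String, Long)] =
    ds.groupBy(ds('name)).agg(count[Person]())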
@@ -159,25 +236,32 @@ private[ops] abstract class AggregatingOps[T, TK <: HList, K <: HList, KT]
TypedDataset.create[Out1](aggregated)
}
- /** Methods on `TypedDataset[T]` that go through a full serialization and
- * deserialization of `T`, and execute outside of the Catalyst runtime.
- */
+ /**
+ * Methods on `TypedDataset[T]` that go through a full serialization and
+ * deserialization of `T`, and execute outside of the Catalyst runtime.
+ */
object deserialized {
+
def mapGroups[U: TypedEncoder](
- f: (KT, Iterator[T]) => U
- )(implicit e: TypedEncoder[KT]): TypedDataset[U] = {
+ f: (KT, Iterator[T]) => U
+ )(implicit
+ e: TypedEncoder[KT]
+ ): TypedDataset[U] = {
val func = (key: KT, it: Iterator[T]) => Iterator(f(key, it))
flatMapGroups(func)
}
def flatMapGroups[U: TypedEncoder](
- f: (KT, Iterator[T]) => TraversableOnce[U]
- )(implicit e: TypedEncoder[KT]): TypedDataset[U] = {
+ f: (KT, Iterator[T]) => TraversableOnce[U]
+ )(implicit
+ e: TypedEncoder[KT]
+ ): TypedDataset[U] = {
implicit val tendcoder = self.encoder
val cols = groupedBy.toList[UntypedExpression[T]]
val logicalPlan = FramelessInternals.logicalPlan(self.dataset)
- val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_))
+ val withKeyColumns =
+ logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_))
val withKey = Project(withKeyColumns, logicalPlan)
val executed = FramelessInternals.executePlan(self.dataset, withKey)
val keyAttributes = executed.analyzed.output.takeRight(cols.size)
@@ -188,7 +272,11 @@ private[ops] abstract class AggregatingOps[T, TK <: HList, K <: HList, KT]
keyAttributes,
dataAttributes,
executed.analyzed
- )(TypedExpressionEncoder[KT], TypedExpressionEncoder[T], TypedExpressionEncoder[U])
+ )(
+ TypedExpressionEncoder[KT],
+ TypedExpressionEncoder[T],
+ TypedExpressionEncoder[U]
+ )
val groupedAndFlatMapped = FramelessInternals.mkDataset(
self.dataset.sqlContext,
@@ -201,66 +289,97 @@ private[ops] abstract class AggregatingOps[T, TK <: HList, K <: HList, KT]
}
private def retainGroupColumns: Boolean = {
- self.dataset.sqlContext.getConf("spark.sql.retainGroupColumns", "true").toBoolean
+ self.dataset.sqlContext
+ .getConf("spark.sql.retainGroupColumns", "true")
+ .toBoolean
}
- def pivot[P: CatalystPivotable](pivotColumn: TypedColumn[T, P]): PivotNotValues[T, TK, P] =
+ def pivot[P: CatalystPivotable](
+ pivotColumn: TypedColumn[T, P]
+ ): PivotNotValues[T, TK, P] =
PivotNotValues(self, groupedBy, pivotColumn)
}
private[ops] object AggregatingOps {
+
/** Utility function to help Spark with serialization of closures */
- def tuple1[K1, V, U](f: (K1, Iterator[V]) => U): (Tuple1[K1], Iterator[V]) => U = {
- (x: Tuple1[K1], it: Iterator[V]) => f(x._1, it)
+ def tuple1[K1, V, U](
+ f: (K1, Iterator[V]) => U
+ ): (Tuple1[K1], Iterator[V]) => U = { (x: Tuple1[K1], it: Iterator[V]) =>
+ f(x._1, it)
}
}
-/** Represents a typed Pivot operation.
- */
+/**
+ * Represents a typed Pivot operation.
+ */
final case class Pivot[T, GroupedColumns <: HList, PivotType, Values <: HList](
- ds: TypedDataset[T],
- groupedBy: GroupedColumns,
- pivotedBy: TypedColumn[T, PivotType],
- values: Values
-) {
+ ds: TypedDataset[T],
+ groupedBy: GroupedColumns,
+ pivotedBy: TypedColumn[T, PivotType],
+ values: Values) {
object agg extends ProductArgs {
- def applyProduct[AggrColumns <: HList, AggrColumnTypes <: HList, GroupedColumnTypes <: HList, NumValues <: Nat, TypesForPivotedValues <: HList, TypesForPivotedValuesOpt <: HList, OutAsHList <: HList, Out]
- (aggrColumns: AggrColumns)
- (implicit
+
+ def applyProduct[
+ AggrColumns <: HList,
+ AggrColumnTypes <: HList,
+ GroupedColumnTypes <: HList,
+ NumValues <: Nat,
+ TypesForPivotedValues <: HList,
+ TypesForPivotedValuesOpt <: HList,
+ OutAsHList <: HList,
+ Out
+ ](aggrColumns: AggrColumns
+ )(implicit
i0: AggregateTypes.Aux[T, AggrColumns, AggrColumnTypes],
i1: ColumnTypes.Aux[T, GroupedColumns, GroupedColumnTypes],
i2: Length.Aux[Values, NumValues],
i3: Repeat.Aux[AggrColumnTypes, NumValues, TypesForPivotedValues],
i4: Mapped.Aux[TypesForPivotedValues, Option, TypesForPivotedValuesOpt],
- i5: Prepend.Aux[GroupedColumnTypes, TypesForPivotedValuesOpt, OutAsHList],
+ i5: Prepend.Aux[
+ GroupedColumnTypes,
+ TypesForPivotedValuesOpt,
+ OutAsHList
+ ],
i6: Tupler.Aux[OutAsHList, Out],
i7: TypedEncoder[Out]
): TypedDataset[Out] = {
- def mapAny[X](h: HList)(f: Any => X): List[X] =
- h match {
- case HNil => Nil
- case x :: xs => f(x) :: mapAny(xs)(f)
- }
-
- val aggCols: Seq[Column] = mapAny(aggrColumns)(x => new Column(x.asInstanceOf[TypedAggregate[_,_]].expr))
- val tmp = ds.dataset.toDF()
- .groupBy(mapAny(groupedBy)(_.asInstanceOf[TypedColumn[_, _]].untyped): _*)
- .pivot(pivotedBy.untyped.toString, mapAny(values)(identity))
- .agg(aggCols.head, aggCols.tail:_*)
- .as[Out](TypedExpressionEncoder[Out])
- TypedDataset.create(tmp)
- }
+ def mapAny[X](h: HList)(f: Any => X): List[X] =
+ h match {
+ case HNil => Nil
+ case x :: xs => f(x) :: mapAny(xs)(f)
+ }
+
+ val aggCols: Seq[Column] = mapAny(aggrColumns)(x =>
+ new Column(x.asInstanceOf[TypedAggregate[_, _]].expr)
+ )
+ val tmp = ds.dataset
+ .toDF()
+ .groupBy(
+ mapAny(groupedBy)(_.asInstanceOf[TypedColumn[_, _]].untyped): _*
+ )
+ .pivot(pivotedBy.untyped.toString, mapAny(values)(identity))
+ .agg(aggCols.head, aggCols.tail: _*)
+ .as[Out](TypedExpressionEncoder[Out])
+ TypedDataset.create(tmp)
+ }
}
}
final case class PivotNotValues[T, GroupedColumns <: HList, PivotType](
- ds: TypedDataset[T],
- groupedBy: GroupedColumns,
- pivotedBy: TypedColumn[T, PivotType]
-) extends ProductArgs {
-
- def onProduct[Values <: HList](values: Values)(
- implicit validValues: ToList[Values, PivotType] // validValues: FilterNot.Aux[Values, PivotType, HNil] // did not work
- ): Pivot[T, GroupedColumns, PivotType, Values] = Pivot(ds, groupedBy, pivotedBy, values)
+ ds: TypedDataset[T],
+ groupedBy: GroupedColumns,
+ pivotedBy: TypedColumn[T, PivotType])
+ extends ProductArgs {
+
+ def onProduct[Values <: HList](
+ values: Values
+ )(implicit
+ validValues: ToList[
+ Values,
+ PivotType
+ ] // validValues: FilterNot.Aux[Values, PivotType, HNil] // did not work
+ ): Pivot[T, GroupedColumns, PivotType, Values] =
+ Pivot(ds, groupedBy, pivotedBy, values)
}
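
For orientation, a pivot built through `PivotNotValues.on(...).agg(...)` might be used as in the sketch below. This is illustrative only: the `Sale` case class is hypothetical, and an implicit `SparkSession` plus the aggregate functions from `frameless.functions.aggregate` are assumed to be in scope.

    import frameless.TypedDataset
    import frameless.functions.aggregate.sum
    import frameless.syntax._

    case class Sale(region: String, quarter: String, amount: Long)

    val sales = TypedDataset.create(
      Seq(Sale("EU", "Q1", 10L), Sale("EU", "Q2", 20L), Sale("US", "Q1", 5L))
    )

    // groupBy(...).pivot(...) yields a PivotNotValues; `on` fixes the pivoted
    // values and `agg` adds one optional aggregate column per pivoted value.
    val byQuarter: Seq[(String, Option[Long], Option[Long])] =
      sales
        .groupBy(sales('region))
        .pivot(sales('quarter))
        .on("Q1", "Q2")
        .agg(sum(sales('amount)))
        .collect()
        .run()
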
diff --git a/dataset/src/main/scala/frameless/ops/RelationalGroupsOps.scala b/dataset/src/main/scala/frameless/ops/RelationalGroupsOps.scala
index 569407762..bfffdb1fb 100644
--- a/dataset/src/main/scala/frameless/ops/RelationalGroupsOps.scala
+++ b/dataset/src/main/scala/frameless/ops/RelationalGroupsOps.scala
@@ -1,49 +1,74 @@
package frameless
package ops
-import org.apache.spark.sql.{Column, Dataset, RelationalGroupedDataset}
-import shapeless.ops.hlist.{Mapped, Prepend, ToTraversable, Tupler}
-import shapeless.{::, HList, HNil, ProductArgs}
+import org.apache.spark.sql.{ Column, Dataset, RelationalGroupedDataset }
+import shapeless.ops.hlist.{ Mapped, Prepend, ToTraversable, Tupler }
+import shapeless.{ ::, HList, HNil, ProductArgs }
/**
- * @param groupingFunc functions used to group elements, can be cube or rollup
- * @tparam T the original `TypedDataset's` type T
- * @tparam TK all columns chosen for aggregation
- * @tparam K individual columns' types as HList
- * @tparam KT individual columns' types as Tuple
- */
-private[ops] abstract class RelationalGroupsOps[T, TK <: HList, K <: HList, KT]
- (self: TypedDataset[T], groupedBy: TK, groupingFunc: (Dataset[T], Seq[Column]) => RelationalGroupedDataset)
- (implicit
+ * @param groupingFunc function used to group elements; can be cube or rollup
+ * @tparam T the original `TypedDataset`'s type T
+ * @tparam TK all columns chosen for grouping
+ * @tparam K individual columns' types as an HList
+ * @tparam KT individual columns' types as a Tuple
+ */
+private[ops] abstract class RelationalGroupsOps[T, TK <: HList, K <: HList, KT](
+ self: TypedDataset[T],
+ groupedBy: TK,
+ groupingFunc: (Dataset[T], Seq[Column]) => RelationalGroupedDataset
+ )(implicit
i0: ColumnTypes.Aux[T, TK, K],
i1: ToTraversable.Aux[TK, List, UntypedExpression[T]],
- i2: Tupler.Aux[K, KT]
- ) extends AggregatingOps(self, groupedBy, groupingFunc){
+ i2: Tupler.Aux[K, KT])
+ extends AggregatingOps(self, groupedBy, groupingFunc) {
object agg extends ProductArgs {
+
/**
- * @tparam TC resulting columns after aggregation function
- * @tparam C individual columns' types as HList
- * @tparam OptK columns' types mapped to Option
- * @tparam Out0 OptK columns appended to C
- * @tparam Out1 output type
- */
- def applyProduct[TC <: HList, C <: HList, OptK <: HList, Out0 <: HList, Out1]
- (columns: TC)
- (implicit
- i3: AggregateTypes.Aux[T, TC, C], // shares individual columns' types after agg function as HList
- i4: Mapped.Aux[K, Option, OptK], // maps all original columns' types to Option
- i5: Prepend.Aux[OptK, C, Out0], // concatenates Option columns with those resulting from applying agg function
- i6: Tupler.Aux[Out0, Out1], // converts resulting HList into Tuple for output type
- i7: TypedEncoder[Out1], // proof that there is `TypedEncoder` for the output type
- i8: ToTraversable.Aux[TC, List, UntypedExpression[T]] // allows converting this HList to ordinary List
- ): TypedDataset[Out1] = {
+ * @tparam TC resulting columns after aggregation function
+ * @tparam C individual columns' types as HList
+ * @tparam OptK columns' types mapped to Option
+ * @tparam Out0 OptK columns appended to C
+ * @tparam Out1 output type
+ */
+ def applyProduct[
+ TC <: HList,
+ C <: HList,
+ OptK <: HList,
+ Out0 <: HList,
+ Out1
+ ](columns: TC
+ )(implicit
+ i3: AggregateTypes.Aux[
+ T,
+ TC,
+ C
+ ], // shares individual columns' types after agg function as HList
+ i4: Mapped.Aux[
+ K,
+ Option,
+ OptK
+ ], // maps all original columns' types to Option
+ i5: Prepend.Aux[OptK, C, Out0], // concatenates Option columns with those resulting from applying agg function
+ i6: Tupler.Aux[
+ Out0,
+ Out1
+ ], // converts resulting HList into Tuple for output type
+ i7: TypedEncoder[
+ Out1
+ ], // proof that there is `TypedEncoder` for the output type
+ i8: ToTraversable.Aux[TC, List, UntypedExpression[
+ T
+ ]] // allows converting this HList to ordinary List
+ ): TypedDataset[Out1] = {
aggregate[TC, Out1](columns)
}
}
}
-private[ops] abstract class RelationalGroups1Ops[K1, V](self: TypedDataset[V], g1: TypedColumn[V, K1]) {
+private[ops] abstract class RelationalGroups1Ops[K1, V](
+ self: TypedDataset[V],
+ g1: TypedColumn[V, K1]) {
protected def underlying: RelationalGroupsOps[V, ::[TypedColumn[V, K1], HNil], ::[K1, HNil], Tuple1[K1]]
private implicit def eg1 = g1.uencoder
@@ -52,117 +77,203 @@ private[ops] abstract class RelationalGroups1Ops[K1, V](self: TypedDataset[V], g
underlying.agg(c1)
}
- def agg[U1, U2](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2]): TypedDataset[(Option[K1], U1, U2)] = {
+ def agg[U1, U2](
+ c1: TypedAggregate[V, U1],
+ c2: TypedAggregate[V, U2]
+ ): TypedDataset[(Option[K1], U1, U2)] = {
implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder
underlying.agg(c1, c2)
}
- def agg[U1, U2, U3](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3]): TypedDataset[(Option[K1], U1, U2, U3)] = {
- implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder
+ def agg[U1, U2, U3](
+ c1: TypedAggregate[V, U1],
+ c2: TypedAggregate[V, U2],
+ c3: TypedAggregate[V, U3]
+ ): TypedDataset[(Option[K1], U1, U2, U3)] = {
+ implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder;
+ implicit val e3 = c3.uencoder
underlying.agg(c1, c2, c3)
}
- def agg[U1, U2, U3, U4](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3], c4: TypedAggregate[V, U4]): TypedDataset[(Option[K1], U1, U2, U3, U4)] = {
- implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder
+ def agg[U1, U2, U3, U4](
+ c1: TypedAggregate[V, U1],
+ c2: TypedAggregate[V, U2],
+ c3: TypedAggregate[V, U3],
+ c4: TypedAggregate[V, U4]
+ ): TypedDataset[(Option[K1], U1, U2, U3, U4)] = {
+ implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder;
+ implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder
underlying.agg(c1, c2, c3, c4)
}
- def agg[U1, U2, U3, U4, U5](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3], c4: TypedAggregate[V, U4], c5: TypedAggregate[V, U5]): TypedDataset[(Option[K1], U1, U2, U3, U4, U5)] = {
- implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder; implicit val e5 = c5.uencoder
+ def agg[U1, U2, U3, U4, U5](
+ c1: TypedAggregate[V, U1],
+ c2: TypedAggregate[V, U2],
+ c3: TypedAggregate[V, U3],
+ c4: TypedAggregate[V, U4],
+ c5: TypedAggregate[V, U5]
+ ): TypedDataset[(Option[K1], U1, U2, U3, U4, U5)] = {
+ implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder;
+ implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder;
+ implicit val e5 = c5.uencoder
underlying.agg(c1, c2, c3, c4, c5)
}
- /** Methods on `TypedDataset[T]` that go through a full serialization and
- * deserialization of `T`, and execute outside of the Catalyst runtime.
- */
+ /**
+ * Methods on `TypedDataset[T]` that go through a full serialization and
+ * deserialization of `T`, and execute outside of the Catalyst runtime.
+ */
object deserialized {
- def mapGroups[U: TypedEncoder](f: (K1, Iterator[V]) => U): TypedDataset[U] = {
+
+ def mapGroups[U: TypedEncoder](
+ f: (K1, Iterator[V]) => U
+ ): TypedDataset[U] = {
underlying.deserialized.mapGroups(AggregatingOps.tuple1(f))
}
- def flatMapGroups[U: TypedEncoder](f: (K1, Iterator[V]) => TraversableOnce[U]): TypedDataset[U] = {
+ def flatMapGroups[U: TypedEncoder](
+ f: (K1, Iterator[V]) => TraversableOnce[U]
+ ): TypedDataset[U] = {
underlying.deserialized.flatMapGroups(AggregatingOps.tuple1(f))
}
}
- def pivot[P: CatalystPivotable](pivotColumn: TypedColumn[V, P]): PivotNotValues[V, TypedColumn[V,K1] :: HNil, P] =
+ def pivot[P: CatalystPivotable](
+ pivotColumn: TypedColumn[V, P]
+ ): PivotNotValues[V, TypedColumn[V, K1] :: HNil, P] =
PivotNotValues(self, g1 :: HNil, pivotColumn)
}
-private[ops] abstract class RelationalGroups2Ops[K1, K2, V](self: TypedDataset[V], g1: TypedColumn[V, K1], g2: TypedColumn[V, K2]) {
+private[ops] abstract class RelationalGroups2Ops[K1, K2, V](
+ self: TypedDataset[V],
+ g1: TypedColumn[V, K1],
+ g2: TypedColumn[V, K2]) {
protected def underlying: RelationalGroupsOps[V, ::[TypedColumn[V, K1], ::[TypedColumn[V, K2], HNil]], ::[K1, ::[K2, HNil]], (K1, K2)]
private implicit def eg1 = g1.uencoder
private implicit def eg2 = g2.uencoder
- def agg[U1](c1: TypedAggregate[V, U1]): TypedDataset[(Option[K1], Option[K2], U1)] = {
+ def agg[U1](
+ c1: TypedAggregate[V, U1]
+ ): TypedDataset[(Option[K1], Option[K2], U1)] = {
implicit val e1 = c1.uencoder
underlying.agg(c1)
}
- def agg[U1, U2](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2]): TypedDataset[(Option[K1], Option[K2], U1, U2)] = {
+ def agg[U1, U2](
+ c1: TypedAggregate[V, U1],
+ c2: TypedAggregate[V, U2]
+ ): TypedDataset[(Option[K1], Option[K2], U1, U2)] = {
implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder
underlying.agg(c1, c2)
}
- def agg[U1, U2, U3](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3]): TypedDataset[(Option[K1], Option[K2], U1, U2, U3)] = {
- implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder
+ def agg[U1, U2, U3](
+ c1: TypedAggregate[V, U1],
+ c2: TypedAggregate[V, U2],
+ c3: TypedAggregate[V, U3]
+ ): TypedDataset[(Option[K1], Option[K2], U1, U2, U3)] = {
+ implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder;
+ implicit val e3 = c3.uencoder
underlying.agg(c1, c2, c3)
}
- def agg[U1, U2, U3, U4](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3], c4: TypedAggregate[V, U4]): TypedDataset[(Option[K1], Option[K2], U1, U2, U3, U4)] = {
- implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder
- underlying.agg(c1 , c2 , c3 , c4)
+ def agg[U1, U2, U3, U4](
+ c1: TypedAggregate[V, U1],
+ c2: TypedAggregate[V, U2],
+ c3: TypedAggregate[V, U3],
+ c4: TypedAggregate[V, U4]
+ ): TypedDataset[(Option[K1], Option[K2], U1, U2, U3, U4)] = {
+ implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder;
+ implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder
+ underlying.agg(c1, c2, c3, c4)
}
- def agg[U1, U2, U3, U4, U5](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3], c4: TypedAggregate[V, U4], c5: TypedAggregate[V, U5]): TypedDataset[(Option[K1], Option[K2], U1, U2, U3, U4, U5)] = {
- implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder; implicit val e5 = c5.uencoder
+ def agg[U1, U2, U3, U4, U5](
+ c1: TypedAggregate[V, U1],
+ c2: TypedAggregate[V, U2],
+ c3: TypedAggregate[V, U3],
+ c4: TypedAggregate[V, U4],
+ c5: TypedAggregate[V, U5]
+ ): TypedDataset[(Option[K1], Option[K2], U1, U2, U3, U4, U5)] = {
+ implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder;
+ implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder;
+ implicit val e5 = c5.uencoder
underlying.agg(c1, c2, c3, c4, c5)
}
- /** Methods on `TypedDataset[T]` that go through a full serialization and
- * deserialization of `T`, and execute outside of the Catalyst runtime.
- */
+ /**
+ * Methods on `TypedDataset[T]` that go through a full serialization and
+ * deserialization of `T`, and execute outside of the Catalyst runtime.
+ */
object deserialized {
- def mapGroups[U: TypedEncoder](f: ((K1, K2), Iterator[V]) => U): TypedDataset[U] = {
+
+ def mapGroups[U: TypedEncoder](
+ f: ((K1, K2), Iterator[V]) => U
+ ): TypedDataset[U] = {
underlying.deserialized.mapGroups(f)
}
- def flatMapGroups[U: TypedEncoder](f: ((K1, K2), Iterator[V]) => TraversableOnce[U]): TypedDataset[U] = {
+ def flatMapGroups[U: TypedEncoder](
+ f: ((K1, K2), Iterator[V]) => TraversableOnce[U]
+ ): TypedDataset[U] = {
underlying.deserialized.flatMapGroups(f)
}
}
- def pivot[P: CatalystPivotable](pivotColumn: TypedColumn[V, P]):
- PivotNotValues[V, TypedColumn[V,K1] :: TypedColumn[V, K2] :: HNil, P] =
+ def pivot[P: CatalystPivotable](
+ pivotColumn: TypedColumn[V, P]
+ ): PivotNotValues[V, TypedColumn[V, K1] :: TypedColumn[V, K2] :: HNil, P] =
PivotNotValues(self, g1 :: g2 :: HNil, pivotColumn)
}
-class RollupManyOps[T, TK <: HList, K <: HList, KT](self: TypedDataset[T], groupedBy: TK)
- (implicit
+class RollupManyOps[T, TK <: HList, K <: HList, KT](
+ self: TypedDataset[T],
+ groupedBy: TK
+ )(implicit
i0: ColumnTypes.Aux[T, TK, K],
i1: ToTraversable.Aux[TK, List, UntypedExpression[T]],
- i2: Tupler.Aux[K, KT]
- ) extends RelationalGroupsOps[T, TK, K, KT](self, groupedBy, (dataset, cols) => dataset.rollup(cols: _*))
+ i2: Tupler.Aux[K, KT])
+ extends RelationalGroupsOps[T, TK, K, KT](
+ self,
+ groupedBy,
+ (dataset, cols) => dataset.rollup(cols: _*)
+ )
-class Rollup1Ops[K1, V](self: TypedDataset[V], g1: TypedColumn[V, K1]) extends RelationalGroups1Ops(self, g1) {
+class Rollup1Ops[K1, V](self: TypedDataset[V], g1: TypedColumn[V, K1])
+ extends RelationalGroups1Ops(self, g1) {
override protected def underlying = new RollupManyOps(self, g1 :: HNil)
}
-class Rollup2Ops[K1, K2, V](self: TypedDataset[V], g1: TypedColumn[V, K1], g2: TypedColumn[V, K2]) extends RelationalGroups2Ops(self, g1, g2) {
+class Rollup2Ops[K1, K2, V](
+ self: TypedDataset[V],
+ g1: TypedColumn[V, K1],
+ g2: TypedColumn[V, K2])
+ extends RelationalGroups2Ops(self, g1, g2) {
override protected def underlying = new RollupManyOps(self, g1 :: g2 :: HNil)
}
-class CubeManyOps[T, TK <: HList, K <: HList, KT](self: TypedDataset[T], groupedBy: TK)
- (implicit
+class CubeManyOps[T, TK <: HList, K <: HList, KT](
+ self: TypedDataset[T],
+ groupedBy: TK
+ )(implicit
i0: ColumnTypes.Aux[T, TK, K],
i1: ToTraversable.Aux[TK, List, UntypedExpression[T]],
- i2: Tupler.Aux[K, KT]
- ) extends RelationalGroupsOps[T, TK, K, KT](self, groupedBy, (dataset, cols) => dataset.cube(cols: _*))
+ i2: Tupler.Aux[K, KT])
+ extends RelationalGroupsOps[T, TK, K, KT](
+ self,
+ groupedBy,
+ (dataset, cols) => dataset.cube(cols: _*)
+ )
-class Cube1Ops[K1, V](self: TypedDataset[V], g1: TypedColumn[V, K1]) extends RelationalGroups1Ops(self, g1) {
+class Cube1Ops[K1, V](self: TypedDataset[V], g1: TypedColumn[V, K1])
+ extends RelationalGroups1Ops(self, g1) {
override protected def underlying = new CubeManyOps(self, g1 :: HNil)
}
-class Cube2Ops[K1, K2, V](self: TypedDataset[V], g1: TypedColumn[V, K1], g2: TypedColumn[V, K2]) extends RelationalGroups2Ops(self, g1, g2) {
+class Cube2Ops[K1, K2, V](
+ self: TypedDataset[V],
+ g1: TypedColumn[V, K1],
+ g2: TypedColumn[V, K2])
+ extends RelationalGroups2Ops(self, g1, g2) {
override protected def underlying = new CubeManyOps(self, g1 :: g2 :: HNil)
}
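
As a usage sketch (hypothetical `Sale` case class, implicit `SparkSession` assumed), a rollup routed through these ops returns the grouping keys as `Option` because the subtotal rows carry nulls for them:

    import frameless.TypedDataset
    import frameless.functions.aggregate.sum
    import frameless.syntax._

    case class Sale(region: String, quarter: String, amount: Long)

    val sales = TypedDataset.create(
      Seq(Sale("EU", "Q1", 10L), Sale("EU", "Q2", 20L), Sale("US", "Q1", 5L))
    )

    // rollup (like cube) goes through RelationalGroupsOps; each grouping key
    // comes back as Option because rollup also emits subtotal/grand-total rows.
    val totals: Seq[(Option[String], Option[String], Long)] =
      sales
        .rollup(sales('region), sales('quarter))
        .agg(sum(sales('amount)))
        .collect()
        .run()
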
diff --git a/dataset/src/main/scala/frameless/ops/Repeat.scala b/dataset/src/main/scala/frameless/ops/Repeat.scala
index bde855500..b9f4d7c94 100644
--- a/dataset/src/main/scala/frameless/ops/Repeat.scala
+++ b/dataset/src/main/scala/frameless/ops/Repeat.scala
@@ -1,33 +1,37 @@
package frameless
package ops
-import shapeless.{HList, Nat, Succ}
+import shapeless.{ HList, Nat, Succ }
import shapeless.ops.hlist.Prepend
-/** Typeclass supporting repeating L-typed HLists N times.
- *
- * Repeat[Int :: String :: HNil, Nat._2].Out =:=
- * Int :: String :: Int :: String :: HNil
- *
- * By Jeremy Smith. To be replaced by `shapeless.ops.hlists.Repeat`
- * once (https://github.com/milessabin/shapeless/pull/730 is published.
- */
+/**
+ * Typeclass supporting repeating L-typed HLists N times.
+ *
+ * Repeat[Int :: String :: HNil, Nat._2].Out =:=
+ * Int :: String :: Int :: String :: HNil
+ *
+ * By Jeremy Smith. To be replaced by `shapeless.ops.hlist.Repeat`
+ * once https://github.com/milessabin/shapeless/pull/730 is published.
+ */
trait Repeat[L <: HList, N <: Nat] {
type Out <: HList
}
object Repeat {
- type Aux[L <: HList, N <: Nat, Out0 <: HList] = Repeat[L, N] { type Out = Out0 }
+
+ type Aux[L <: HList, N <: Nat, Out0 <: HList] = Repeat[L, N] {
+ type Out = Out0
+ }
implicit def base[L <: HList]: Aux[L, Nat._1, L] = new Repeat[L, Nat._1] {
type Out = L
}
- implicit def succ[L <: HList, Prev <: Nat, PrevOut <: HList, P <: HList]
- (implicit
+ implicit def succ[L <: HList, Prev <: Nat, PrevOut <: HList, P <: HList](
+ implicit
i0: Aux[L, Prev, PrevOut],
i1: Prepend.Aux[L, PrevOut, P]
): Aux[L, Succ[Prev], P] = new Repeat[L, Succ[Prev]] {
- type Out = P
- }
+ type Out = P
+ }
}
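
A minimal compile-time sketch of what the typeclass asserts (nothing here runs at runtime; it only checks that implicit resolution produces the expected `Out`):

    import shapeless.{ ::, HNil, Nat }
    import frameless.ops.Repeat

    // Repeating Int :: String :: HNil twice should yield
    // Int :: String :: Int :: String :: HNil.
    val repeated = implicitly[Repeat[Int :: String :: HNil, Nat._2]]
    implicitly[repeated.Out =:= (Int :: String :: Int :: String :: HNil)]
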
diff --git a/dataset/src/main/scala/frameless/ops/SmartProject.scala b/dataset/src/main/scala/frameless/ops/SmartProject.scala
index ec3628efd..7242975f4 100644
--- a/dataset/src/main/scala/frameless/ops/SmartProject.scala
+++ b/dataset/src/main/scala/frameless/ops/SmartProject.scala
@@ -2,38 +2,47 @@ package frameless
package ops
import shapeless.ops.hlist.ToTraversable
-import shapeless.ops.record.{Keys, SelectAll, Values}
-import shapeless.{HList, LabelledGeneric}
+import shapeless.ops.record.{ Keys, SelectAll, Values }
+import shapeless.{ HList, LabelledGeneric }
import scala.annotation.implicitNotFound
@implicitNotFound(msg = "Cannot prove that ${T} can be projected to ${U}. Perhaps not all member names and types of ${U} are the same in ${T}?")
-case class SmartProject[T: TypedEncoder, U: TypedEncoder](apply: TypedDataset[T] => TypedDataset[U])
+case class SmartProject[T: TypedEncoder, U: TypedEncoder](
+ apply: TypedDataset[T] => TypedDataset[U])
object SmartProject {
+
/**
- * Proofs that there is a type-safe projection from a type T to another type U. It requires that:
- * (a) both T and U are Products for which a LabelledGeneric can be derived (e.g., case classes),
- * (b) all members of U have a corresponding member in T that has both the same name and type.
- *
- * @param i0 the LabelledGeneric derived for T
- * @param i1 the LabelledGeneric derived for U
- * @param i2 the keys of U
- * @param i3 selects all the values from T using the keys of U
- * @param i4 selects all the values of LabeledGeneric[U]
- * @param i5 proof that U and the projection of T have the same type
- * @param i6 allows for traversing the keys of U
- * @tparam T the original type T
- * @tparam U the projected type U
- * @tparam TRec shapeless' Record representation of T
- * @tparam TProj the projection of T using the keys of U
- * @tparam URec shapeless' Record representation of U
- * @tparam UVals the values of U as an HList
- * @tparam UKeys the keys of U as an HList
- * @return a projection if it exists
- */
- implicit def deriveProduct[T: TypedEncoder, U: TypedEncoder, TRec <: HList, TProj <: HList, URec <: HList, UVals <: HList, UKeys <: HList]
- (implicit
+ * Proof that there is a type-safe projection from a type T to another type U. It requires that:
+ * (a) both T and U are Products for which a LabelledGeneric can be derived (e.g., case classes),
+ * (b) all members of U have a corresponding member in T that has both the same name and type.
+ *
+ * @param i0 the LabelledGeneric derived for T
+ * @param i1 the LabelledGeneric derived for U
+ * @param i2 the keys of U
+ * @param i3 selects all the values from T using the keys of U
+ * @param i4 selects all the values of LabelledGeneric[U]
+ * @param i5 proof that U and the projection of T have the same type
+ * @param i6 allows for traversing the keys of U
+ * @tparam T the original type T
+ * @tparam U the projected type U
+ * @tparam TRec shapeless' Record representation of T
+ * @tparam TProj the projection of T using the keys of U
+ * @tparam URec shapeless' Record representation of U
+ * @tparam UVals the values of U as an HList
+ * @tparam UKeys the keys of U as an HList
+ * @return a projection if it exists
+ */
+ implicit def deriveProduct[
+ T: TypedEncoder,
+ U: TypedEncoder,
+ TRec <: HList,
+ TProj <: HList,
+ URec <: HList,
+ UVals <: HList,
+ UKeys <: HList
+ ](implicit
i0: LabelledGeneric.Aux[T, TRec],
i1: LabelledGeneric.Aux[U, URec],
i2: Keys.Aux[URec, UKeys],
@@ -41,8 +50,14 @@ object SmartProject {
i4: Values.Aux[URec, UVals],
i5: UVals =:= TProj,
i6: ToTraversable.Aux[UKeys, Seq, Symbol]
- ): SmartProject[T,U] = SmartProject[T, U]({ from =>
- val names = implicitly[Keys.Aux[URec, UKeys]].apply().to[Seq].map(_.name).map(from.dataset.col)
- TypedDataset.create(from.dataset.toDF().select(names: _*).as[U](TypedExpressionEncoder[U]))
- })
+ ): SmartProject[T, U] = SmartProject[T, U]({ from =>
+ val names = implicitly[Keys.Aux[URec, UKeys]]
+ .apply()
+ .to[Seq]
+ .map(_.name)
+ .map(from.dataset.col)
+ TypedDataset.create(
+ from.dataset.toDF().select(names: _*).as[U](TypedExpressionEncoder[U])
+ )
+ })
}
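
A small sketch of the projection this derives (the `Person`/`NameAge` case classes are illustrative; an implicit `SparkSession` is assumed):

    import frameless.TypedDataset

    case class Person(name: String, age: Int, city: String)
    case class NameAge(name: String, age: Int)

    val people = TypedDataset.create(
      Seq(Person("Ada", 36, "London"), Person("Alan", 41, "Manchester"))
    )

    // A SmartProject[Person, NameAge] is derived because every field of NameAge
    // exists in Person with the same name and type; a mismatch would fail to
    // compile with the implicitNotFound message above.
    val namesAndAges: TypedDataset[NameAge] = people.project[NameAge]
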
diff --git a/dataset/src/main/scala/frameless/syntax/package.scala b/dataset/src/main/scala/frameless/syntax/package.scala
index c6045981f..7050f93e8 100644
--- a/dataset/src/main/scala/frameless/syntax/package.scala
+++ b/dataset/src/main/scala/frameless/syntax/package.scala
@@ -1,5 +1,7 @@
package frameless
package object syntax extends FramelessSyntax {
- implicit val DefaultSparkDelay: SparkDelay[Job] = Job.framelessSparkDelayForJob
+
+ implicit val DefaultSparkDelay: SparkDelay[Job] =
+ Job.framelessSparkDelayForJob
}
diff --git a/dataset/src/main/scala/org/apache/spark/sql/FramelessInternals.scala b/dataset/src/main/scala/org/apache/spark/sql/FramelessInternals.scala
index 5459230d4..5cdb34155 100644
--- a/dataset/src/main/scala/org/apache/spark/sql/FramelessInternals.scala
+++ b/dataset/src/main/scala/org/apache/spark/sql/FramelessInternals.scala
@@ -2,24 +2,34 @@ package org.apache.spark.sql
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct}
-import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{ Alias, CreateStruct }
+import org.apache.spark.sql.catalyst.expressions.{ Expression, NamedExpression }
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{ LogicalPlan, Project }
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.ObjectType
import scala.reflect.ClassTag
object FramelessInternals {
- def objectTypeFor[A](implicit classTag: ClassTag[A]): ObjectType = ObjectType(classTag.runtimeClass)
+
+ def objectTypeFor[A](
+ implicit
+ classTag: ClassTag[A]
+ ): ObjectType = ObjectType(classTag.runtimeClass)
def resolveExpr(ds: Dataset[_], colNames: Seq[String]): NamedExpression = {
- ds.toDF().queryExecution.analyzed.resolve(colNames, ds.sparkSession.sessionState.analyzer.resolver).getOrElse {
- throw new AnalysisException(
- s"""Cannot resolve column name "$colNames" among (${ds.schema.fieldNames.mkString(", ")})""")
- }
+ ds.toDF()
+ .queryExecution
+ .analyzed
+ .resolve(colNames, ds.sparkSession.sessionState.analyzer.resolver)
+ .getOrElse {
+ throw new AnalysisException(
+ s"""Cannot resolve column name "$colNames" among (${ds.schema.fieldNames
+ .mkString(", ")})"""
+ )
+ }
}
def expr(column: Column): Expression = column.expr
@@ -29,45 +39,72 @@ object FramelessInternals {
def executePlan(ds: Dataset[_], plan: LogicalPlan): QueryExecution =
ds.sparkSession.sessionState.executePlan(plan)
- def joinPlan(ds: Dataset[_], plan: LogicalPlan, leftPlan: LogicalPlan, rightPlan: LogicalPlan): LogicalPlan = {
+ def joinPlan(
+ ds: Dataset[_],
+ plan: LogicalPlan,
+ leftPlan: LogicalPlan,
+ rightPlan: LogicalPlan
+ ): LogicalPlan = {
val joined = executePlan(ds, plan)
val leftOutput = joined.analyzed.output.take(leftPlan.output.length)
val rightOutput = joined.analyzed.output.takeRight(rightPlan.output.length)
- Project(List(
- Alias(CreateStruct(leftOutput), "_1")(),
- Alias(CreateStruct(rightOutput), "_2")()
- ), joined.analyzed)
+ Project(
+ List(
+ Alias(CreateStruct(leftOutput), "_1")(),
+ Alias(CreateStruct(rightOutput), "_2")()
+ ),
+ joined.analyzed
+ )
}
- def mkDataset[T](sqlContext: SQLContext, plan: LogicalPlan, encoder: Encoder[T]): Dataset[T] =
+ def mkDataset[T](
+ sqlContext: SQLContext,
+ plan: LogicalPlan,
+ encoder: Encoder[T]
+ ): Dataset[T] =
new Dataset(sqlContext, plan, encoder)
def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame =
Dataset.ofRows(sparkSession, logicalPlan)
// because org.apache.spark.sql.types.UserDefinedType is private[spark]
- type UserDefinedType[A >: Null] = org.apache.spark.sql.types.UserDefinedType[A]
+ type UserDefinedType[A >: Null] =
+ org.apache.spark.sql.types.UserDefinedType[A]
// below only tested in SelfJoinTests.colLeft and colRight are equivalent to col outside of joins
// - via files (codegen) forces doGenCode eval.
/** Expression to tag columns from the left hand side of join expression. */
- case class DisambiguateLeft[T](tagged: Expression) extends Expression with NonSQLExpression {
+ case class DisambiguateLeft[T](tagged: Expression)
+ extends Expression
+ with NonSQLExpression {
def eval(input: InternalRow): Any = tagged.eval(input)
def nullable: Boolean = false
def children: Seq[Expression] = tagged :: Nil
def dataType: DataType = tagged.dataType
- protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = tagged.genCode(ctx)
- protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(newChildren.head)
+
+ protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+ tagged.genCode(ctx)
+
+ protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]
+ ): Expression = copy(newChildren.head)
}
/** Expression to tag columns from the right hand side of join expression. */
- case class DisambiguateRight[T](tagged: Expression) extends Expression with NonSQLExpression {
+ case class DisambiguateRight[T](tagged: Expression)
+ extends Expression
+ with NonSQLExpression {
def eval(input: InternalRow): Any = tagged.eval(input)
def nullable: Boolean = false
def children: Seq[Expression] = tagged :: Nil
def dataType: DataType = tagged.dataType
- protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = tagged.genCode(ctx)
- protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(newChildren.head)
+
+ protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+ tagged.genCode(ctx)
+
+ protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]
+ ): Expression = copy(newChildren.head)
}
}
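
For context, `DisambiguateLeft`/`DisambiguateRight` back the `colLeft`/`colRight` columns used in self-joins (see `SelfJoinTests`). A rough sketch, with a hypothetical `Edge` case class and an implicit `SparkSession` assumed:

    import frameless.TypedDataset
    import frameless.syntax._

    case class Edge(from: Int, to: Int)

    val edges = TypedDataset.create(Seq(Edge(1, 2), Edge(2, 3)))

    // In a self-join both sides expose the same attributes; colLeft/colRight
    // tag them so the join condition can refer to each side unambiguously.
    val twoHops: Seq[(Edge, Edge)] = edges
      .joinInner(edges)(edges.colLeft('to) === edges.colRight('from))
      .collect()
      .run()
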
diff --git a/dataset/src/main/spark-3.4+/frameless/MapGroups.scala b/dataset/src/main/spark-3.4+/frameless/MapGroups.scala
index 6856acba4..25411420b 100644
--- a/dataset/src/main/spark-3.4+/frameless/MapGroups.scala
+++ b/dataset/src/main/spark-3.4+/frameless/MapGroups.scala
@@ -2,15 +2,19 @@ package frameless
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, MapGroups => SMapGroups}
+import org.apache.spark.sql.catalyst.plans.logical.{
+ LogicalPlan,
+ MapGroups => SMapGroups
+}
object MapGroups {
+
def apply[K: Encoder, T: Encoder, U: Encoder](
- func: (K, Iterator[T]) => TraversableOnce[U],
- groupingAttributes: Seq[Attribute],
- dataAttributes: Seq[Attribute],
- child: LogicalPlan
- ): LogicalPlan =
+ func: (K, Iterator[T]) => TraversableOnce[U],
+ groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
+ child: LogicalPlan
+ ): LogicalPlan =
SMapGroups(
func,
groupingAttributes,
diff --git a/dataset/src/main/spark-3/frameless/MapGroups.scala b/dataset/src/main/spark-3/frameless/MapGroups.scala
index 3fd27f333..67ec8b731 100644
--- a/dataset/src/main/spark-3/frameless/MapGroups.scala
+++ b/dataset/src/main/spark-3/frameless/MapGroups.scala
@@ -2,13 +2,17 @@ package frameless
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, MapGroups => SMapGroups}
+import org.apache.spark.sql.catalyst.plans.logical.{
+ LogicalPlan,
+ MapGroups => SMapGroups
+}
object MapGroups {
+
def apply[K: Encoder, T: Encoder, U: Encoder](
- func: (K, Iterator[T]) => TraversableOnce[U],
- groupingAttributes: Seq[Attribute],
- dataAttributes: Seq[Attribute],
- child: LogicalPlan
- ): LogicalPlan = SMapGroups(func, groupingAttributes, dataAttributes, child)
+ func: (K, Iterator[T]) => TraversableOnce[U],
+ groupingAttributes: Seq[Attribute],
+ dataAttributes: Seq[Attribute],
+ child: LogicalPlan
+ ): LogicalPlan = SMapGroups(func, groupingAttributes, dataAttributes, child)
}
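
Both source trees shim the Catalyst `MapGroups` node across Spark versions; user code reaches it through the `deserialized` API shown earlier. A sketch of that call path (hypothetical `Sale` case class, implicit `SparkSession` assumed):

    import frameless.TypedDataset
    import frameless.syntax._

    case class Sale(region: String, amount: Long)

    val sales = TypedDataset.create(
      Seq(Sale("EU", 10L), Sale("EU", 20L), Sale("US", 5L))
    )

    // deserialized.mapGroups runs outside Catalyst and is ultimately planned
    // through the MapGroups shim above.
    val totals: Seq[(String, Long)] = sales
      .groupBy(sales('region))
      .deserialized
      .mapGroups((region, rows) => (region, rows.map(_.amount).sum))
      .collect()
      .run()
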
diff --git a/dataset/src/test/scala/frameless/AsTests.scala b/dataset/src/test/scala/frameless/AsTests.scala
index c1091f9ca..5dd19a6ce 100644
--- a/dataset/src/test/scala/frameless/AsTests.scala
+++ b/dataset/src/test/scala/frameless/AsTests.scala
@@ -5,14 +5,15 @@ import org.scalacheck.Prop._
class AsTests extends TypedDatasetSuite {
test("as[X2[A, B]]") {
- def prop[A, B](data: Vector[(A, B)])(
- implicit
- eab: TypedEncoder[(A, B)],
- ex2: TypedEncoder[X2[A, B]]
- ): Prop = {
+ def prop[A, B](
+ data: Vector[(A, B)]
+ )(implicit
+ eab: TypedEncoder[(A, B)],
+ ex2: TypedEncoder[X2[A, B]]
+ ): Prop = {
val dataset = TypedDataset.create(data)
- val dataset2 = dataset.as[X2[A,B]]().collect().run().toVector
+ val dataset2 = dataset.as[X2[A, B]]().collect().run().toVector
val data2 = data.map { case (a, b) => X2(a, b) }
dataset2 ?= data2
@@ -27,17 +28,16 @@ class AsTests extends TypedDatasetSuite {
}
test("as[X2[X2[A, B], C]") {
- def prop[A, B, C](data: Vector[(A, B, C)])(
- implicit
- eab: TypedEncoder[((A, B), C)],
- ex2: TypedEncoder[X2[X2[A, B], C]]
- ): Prop = {
- val data2 = data.map {
- case (a, b, c) => ((a, b), c)
- }
+ def prop[A, B, C](
+ data: Vector[(A, B, C)]
+ )(implicit
+ eab: TypedEncoder[((A, B), C)],
+ ex2: TypedEncoder[X2[X2[A, B], C]]
+ ): Prop = {
+ val data2 = data.map { case (a, b, c) => ((a, b), c) }
val dataset = TypedDataset.create(data2)
- val dataset2 = dataset.as[X2[X2[A,B], C]]().collect().run().toVector
+ val dataset2 = dataset.as[X2[X2[A, B], C]]().collect().run().toVector
val data3 = data2.map { case ((a, b), c) => X2(X2(a, b), c) }
dataset2 ?= data3
@@ -47,7 +47,13 @@ class AsTests extends TypedDatasetSuite {
check(forAll(prop[String, Int, String] _))
check(forAll(prop[String, String, Int] _))
check(forAll(prop[Long, Int, String] _))
- check(forAll(prop[Seq[Seq[Option[Seq[Long]]]], Seq[Int], Option[Seq[Option[Int]]]] _))
- check(forAll(prop[Seq[Option[Seq[String]]], Seq[Int], Seq[Option[String]]] _))
+ check(
+ forAll(
+ prop[Seq[Seq[Option[Seq[Long]]]], Seq[Int], Option[Seq[Option[Int]]]] _
+ )
+ )
+ check(
+ forAll(prop[Seq[Option[Seq[String]]], Seq[Int], Seq[Option[String]]] _)
+ )
}
}
diff --git a/dataset/src/test/scala/frameless/BitwiseTests.scala b/dataset/src/test/scala/frameless/BitwiseTests.scala
index f58c906a2..0d1914de0 100644
--- a/dataset/src/test/scala/frameless/BitwiseTests.scala
+++ b/dataset/src/test/scala/frameless/BitwiseTests.scala
@@ -7,12 +7,12 @@ import org.scalatest.matchers.should.Matchers
class BitwiseTests extends TypedDatasetSuite with Matchers {
/**
- * providing instances with implementations for bitwise operations since in the tests
- * we need to check the results from frameless vs the results from normal scala operators
- * for Numeric it is easy to test since scala comes with Numeric typeclass but there seems
- * to be no equivalent typeclass for bitwise ops for Byte Short Int and Long types supported in Catalyst
- */
- trait CatalystBitwise4Tests[A]{
+ * Provides instances with implementations for bitwise operations, since the tests
+ * need to check the results from frameless against the results from the normal Scala
+ * operators. For Numeric this is easy because Scala ships a Numeric typeclass, but
+ * there seems to be no equivalent typeclass for the bitwise ops on the Byte, Short,
+ * Int and Long types supported in Catalyst.
+ */
+ trait CatalystBitwise4Tests[A] {
def bitwiseAnd(a1: A, a2: A): A
def bitwiseOr(a1: A, a2: A): A
def bitwiseXor(a1: A, a2: A): A
@@ -22,33 +22,44 @@ class BitwiseTests extends TypedDatasetSuite with Matchers {
}
object CatalystBitwise4Tests {
- implicit val framelessbyteBitwise : CatalystBitwise4Tests[Byte] = new CatalystBitwise4Tests[Byte] {
- def bitwiseOr(a1: Byte, a2: Byte) : Byte = (a1 | a2).toByte
- def bitwiseAnd(a1: Byte, a2: Byte): Byte = (a1 & a2).toByte
- def bitwiseXor(a1: Byte, a2: Byte): Byte = (a1 ^ a2).toByte
- }
- implicit val framelessshortBitwise : CatalystBitwise4Tests[Short] = new CatalystBitwise4Tests[Short] {
- def bitwiseOr(a1: Short, a2: Short) : Short = (a1 | a2).toShort
- def bitwiseAnd(a1: Short, a2: Short): Short = (a1 & a2).toShort
- def bitwiseXor(a1: Short, a2: Short): Short = (a1 ^ a2).toShort
- }
- implicit val framelessintBitwise : CatalystBitwise4Tests[Int] = new CatalystBitwise4Tests[Int] {
- def bitwiseOr(a1: Int, a2: Int) : Int = a1 | a2
- def bitwiseAnd(a1: Int, a2: Int): Int = a1 & a2
- def bitwiseXor(a1: Int, a2: Int): Int = a1 ^ a2
- }
- implicit val framelesslongBitwise : CatalystBitwise4Tests[Long] = new CatalystBitwise4Tests[Long] {
- def bitwiseOr(a1: Long, a2: Long) : Long = a1 | a2
- def bitwiseAnd(a1: Long, a2: Long): Long = a1 & a2
- def bitwiseXor(a1: Long, a2: Long): Long = a1 ^ a2
- }
+
+ implicit val framelessbyteBitwise: CatalystBitwise4Tests[Byte] =
+ new CatalystBitwise4Tests[Byte] {
+ def bitwiseOr(a1: Byte, a2: Byte): Byte = (a1 | a2).toByte
+ def bitwiseAnd(a1: Byte, a2: Byte): Byte = (a1 & a2).toByte
+ def bitwiseXor(a1: Byte, a2: Byte): Byte = (a1 ^ a2).toByte
+ }
+
+ implicit val framelessshortBitwise: CatalystBitwise4Tests[Short] =
+ new CatalystBitwise4Tests[Short] {
+ def bitwiseOr(a1: Short, a2: Short): Short = (a1 | a2).toShort
+ def bitwiseAnd(a1: Short, a2: Short): Short = (a1 & a2).toShort
+ def bitwiseXor(a1: Short, a2: Short): Short = (a1 ^ a2).toShort
+ }
+
+ implicit val framelessintBitwise: CatalystBitwise4Tests[Int] =
+ new CatalystBitwise4Tests[Int] {
+ def bitwiseOr(a1: Int, a2: Int): Int = a1 | a2
+ def bitwiseAnd(a1: Int, a2: Int): Int = a1 & a2
+ def bitwiseXor(a1: Int, a2: Int): Int = a1 ^ a2
+ }
+
+ implicit val framelesslongBitwise: CatalystBitwise4Tests[Long] =
+ new CatalystBitwise4Tests[Long] {
+ def bitwiseOr(a1: Long, a2: Long): Long = a1 | a2
+ def bitwiseAnd(a1: Long, a2: Long): Long = a1 & a2
+ def bitwiseXor(a1: Long, a2: Long): Long = a1 ^ a2
+ }
}
import CatalystBitwise4Tests._
test("bitwiseAND") {
- def prop[A: TypedEncoder: CatalystBitwise](a: A, b: A)(
- implicit catalystBitwise4Tests: CatalystBitwise4Tests[A]
- ): Prop = {
+ def prop[A: TypedEncoder: CatalystBitwise](
+ a: A,
+ b: A
+ )(implicit
+ catalystBitwise4Tests: CatalystBitwise4Tests[A]
+ ): Prop = {
val df = TypedDataset.create(X2(a, b) :: Nil)
val result = implicitly[CatalystBitwise4Tests[A]].bitwiseAnd(a, b)
val resultSymbolic = implicitly[CatalystBitwise4Tests[A]].&(a, b)
@@ -56,7 +67,9 @@ class BitwiseTests extends TypedDatasetSuite with Matchers {
val gotSymbolic = df.select(df.col('a) & b).collect().run()
val symbolicCol2Col = df.select(df.col('a) & df.col('b)).collect().run()
val canCast = df.select(df.col('a).cast[Long] & 0L).collect().run()
- canCast should contain theSameElementsAs Seq.fill[Long](gotSymbolic.size)(0L)
+ canCast should contain theSameElementsAs Seq.fill[Long](gotSymbolic.size)(
+ 0L
+ )
result ?= resultSymbolic
symbolicCol2Col ?= (result :: Nil)
got ?= (result :: Nil)
@@ -70,9 +83,12 @@ class BitwiseTests extends TypedDatasetSuite with Matchers {
}
test("bitwiseOR") {
- def prop[A: TypedEncoder: CatalystBitwise](a: A, b: A)(
- implicit catalystBitwise4Tests: CatalystBitwise4Tests[A]
- ): Prop = {
+ def prop[A: TypedEncoder: CatalystBitwise](
+ a: A,
+ b: A
+ )(implicit
+ catalystBitwise4Tests: CatalystBitwise4Tests[A]
+ ): Prop = {
val df = TypedDataset.create(X2(a, b) :: Nil)
val result = implicitly[CatalystBitwise4Tests[A]].bitwiseOr(a, b)
val resultSymbolic = implicitly[CatalystBitwise4Tests[A]].|(a, b)
@@ -80,7 +96,9 @@ class BitwiseTests extends TypedDatasetSuite with Matchers {
val gotSymbolic = df.select(df.col('a) | b).collect().run()
val symbolicCol2Col = df.select(df.col('a) | df.col('b)).collect().run()
val canCast = df.select(df.col('a).cast[Long] | -1L).collect().run()
- canCast should contain theSameElementsAs Seq.fill[Long](gotSymbolic.size)(-1L)
+ canCast should contain theSameElementsAs Seq.fill[Long](gotSymbolic.size)(
+ -1L
+ )
result ?= resultSymbolic
symbolicCol2Col ?= (result :: Nil)
got ?= (result :: Nil)
@@ -94,9 +112,12 @@ class BitwiseTests extends TypedDatasetSuite with Matchers {
}
test("bitwiseXOR") {
- def prop[A: TypedEncoder: CatalystBitwise](a: A, b: A)(
- implicit catalystBitwise4Tests: CatalystBitwise4Tests[A]
- ): Prop = {
+ def prop[A: TypedEncoder: CatalystBitwise](
+ a: A,
+ b: A
+ )(implicit
+ catalystBitwise4Tests: CatalystBitwise4Tests[A]
+ ): Prop = {
val df = TypedDataset.create(X2(a, b) :: Nil)
val result = implicitly[CatalystBitwise4Tests[A]].bitwiseXor(a, b)
val resultSymbolic = implicitly[CatalystBitwise4Tests[A]].^(a, b)
@@ -104,7 +125,9 @@ class BitwiseTests extends TypedDatasetSuite with Matchers {
val got = df.select(df.col('a) bitwiseXOR df.col('b)).collect().run()
val gotSymbolic = df.select(df.col('a) ^ b).collect().run()
val zeroes = df.select(df.col('a) ^ df.col('a)).collect().run()
- zeroes should contain theSameElementsAs Seq.fill[Long](gotSymbolic.size)(0L)
+ zeroes should contain theSameElementsAs Seq.fill[Long](gotSymbolic.size)(
+ 0L
+ )
got ?= (result :: Nil)
gotSymbolic ?= (resultSymbolic :: Nil)
}
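
Outside the test harness, the operations under test look like the following sketch (hypothetical `Pair` case class, implicit `SparkSession` assumed):

    import frameless.TypedDataset
    import frameless.syntax._

    case class Pair(a: Int, b: Int)

    val pair = TypedDataset.create(Seq(Pair(5, 3)))

    // bitwiseAND/bitwiseOR/bitwiseXOR (and their symbolic aliases &, |, ^)
    // require a CatalystBitwise instance, provided for Byte, Short, Int and Long.
    val anded: Seq[Int] = pair.select(pair('a) bitwiseAND pair('b)).collect().run()
    // anded == Seq(1), since 5 & 3 == 1
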
diff --git a/dataset/src/test/scala/frameless/CastTests.scala b/dataset/src/test/scala/frameless/CastTests.scala
index 5f79f8fa6..00328d16c 100644
--- a/dataset/src/test/scala/frameless/CastTests.scala
+++ b/dataset/src/test/scala/frameless/CastTests.scala
@@ -1,14 +1,16 @@
package frameless
-import org.scalacheck.{Arbitrary, Gen, Prop}
+import org.scalacheck.{ Arbitrary, Gen, Prop }
import org.scalacheck.Prop._
class CastTests extends TypedDatasetSuite {
- def prop[A: TypedEncoder, B: TypedEncoder](f: A => B)(a: A)(
- implicit
- cast: CatalystCast[A, B]
- ): Prop = {
+ def prop[A: TypedEncoder, B: TypedEncoder](
+ f: A => B
+ )(a: A
+ )(implicit
+ cast: CatalystCast[A, B]
+ ): Prop = {
val df = TypedDataset.create(X1(a) :: Nil)
val got = df.select(df.col('a).cast[B]).collect().run()
@@ -100,9 +102,11 @@ class CastTests extends TypedDatasetSuite {
check(prop[Short, Boolean](_ != 0) _)
// booleanToNumeric
- check(prop[Boolean, BigDecimal](x => if (x) BigDecimal(1) else BigDecimal(0)) _)
+ check(
+ prop[Boolean, BigDecimal](x => if (x) BigDecimal(1) else BigDecimal(0)) _
+ )
check(prop[Boolean, Byte](x => if (x) 1 else 0) _)
- check(prop[Boolean, Double](x => if (x) 1.0f else 0.0f) _)
+ check(prop[Boolean, Double](x => if (x) 1.0F else 0.0F) _)
check(prop[Boolean, Int](x => if (x) 1 else 0) _)
check(prop[Boolean, Long](x => if (x) 1L else 0L) _)
check(prop[Boolean, Short](x => if (x) 1 else 0) _)
diff --git a/dataset/src/test/scala/frameless/ColTests.scala b/dataset/src/test/scala/frameless/ColTests.scala
index ad62aa068..d71174ce7 100644
--- a/dataset/src/test/scala/frameless/ColTests.scala
+++ b/dataset/src/test/scala/frameless/ColTests.scala
@@ -16,7 +16,10 @@ class ColTests extends TypedDatasetSuite {
x4.col[Int]('a)
t4.col[Int]('_1)
- illTyped("x4.col[String]('a)", "No column .* of type String in frameless.X4.*")
+ illTyped(
+ "x4.col[String]('a)",
+ "No column .* of type String in frameless.X4.*"
+ )
x4.col('b)
t4.col('_2)
diff --git a/dataset/src/test/scala/frameless/CollectTests.scala b/dataset/src/test/scala/frameless/CollectTests.scala
index 0ff1e6956..56f661961 100644
--- a/dataset/src/test/scala/frameless/CollectTests.scala
+++ b/dataset/src/test/scala/frameless/CollectTests.scala
@@ -85,10 +85,18 @@ class CollectTests extends TypedDatasetSuite {
object CollectTests {
import frameless.syntax._
- def prop[A: TypedEncoder : ClassTag](data: Vector[A])(implicit c: SparkSession): Prop =
+ def prop[A: TypedEncoder: ClassTag](
+ data: Vector[A]
+ )(implicit
+ c: SparkSession
+ ): Prop =
TypedDataset.create(data).collect().run().toVector ?= data
- def propArray[A: TypedEncoder : ClassTag](data: Vector[X1[Array[A]]])(implicit c: SparkSession): Prop =
+ def propArray[A: TypedEncoder: ClassTag](
+ data: Vector[X1[Array[A]]]
+ )(implicit
+ c: SparkSession
+ ): Prop =
Prop(TypedDataset.create(data).collect().run().toVector.zip(data).forall {
case (X1(l), X1(r)) => l.sameElements(r)
})
diff --git a/dataset/src/test/scala/frameless/ColumnViaLambdaTests.scala b/dataset/src/test/scala/frameless/ColumnViaLambdaTests.scala
index 0a9c532a6..ec3f66ac7 100644
--- a/dataset/src/test/scala/frameless/ColumnViaLambdaTests.scala
+++ b/dataset/src/test/scala/frameless/ColumnViaLambdaTests.scala
@@ -11,9 +11,12 @@ case class MyClass4(h: Boolean)
final class ColumnViaLambdaTests extends TypedDatasetSuite with Matchers {
def ds = {
- TypedDataset.create(Seq(
- MyClass1(1, "2", MyClass2(3L, MyClass3(7.0D)), Some(MyClass4(true))),
- MyClass1(4, "5", MyClass2(6L, MyClass3(8.0D)), None)))
+ TypedDataset.create(
+ Seq(
+ MyClass1(1, "2", MyClass2(3L, MyClass3(7.0D)), Some(MyClass4(true))),
+ MyClass1(4, "5", MyClass2(6L, MyClass3(8.0D)), None)
+ )
+ )
}
test("col(_.a)") {
diff --git a/dataset/src/test/scala/frameless/CreateTests.scala b/dataset/src/test/scala/frameless/CreateTests.scala
index 4d9b5547d..aebb30b7b 100644
--- a/dataset/src/test/scala/frameless/CreateTests.scala
+++ b/dataset/src/test/scala/frameless/CreateTests.scala
@@ -1,6 +1,6 @@
package frameless
-import org.scalacheck.{Arbitrary, Prop}
+import org.scalacheck.{ Arbitrary, Prop }
import org.scalacheck.Prop._
import scala.reflect.ClassTag
@@ -13,29 +13,40 @@ class CreateTests extends TypedDatasetSuite with Matchers {
test("creation using X4 derived DataFrames") {
def prop[
- A: TypedEncoder,
- B: TypedEncoder,
- C: TypedEncoder,
- D: TypedEncoder](data: Vector[X4[A, B, C, D]]): Prop = {
+ A: TypedEncoder,
+ B: TypedEncoder,
+ C: TypedEncoder,
+ D: TypedEncoder
+ ](data: Vector[X4[A, B, C, D]]
+ ): Prop = {
val ds = TypedDataset.create(data)
- TypedDataset.createUnsafe[X4[A, B, C, D]](ds.toDF()).collect().run() ?= data
+ TypedDataset
+ .createUnsafe[X4[A, B, C, D]](ds.toDF())
+ .collect()
+ .run() ?= data
}
check(forAll(prop[Int, Char, X2[Option[Country], Country], Int] _))
check(forAll(prop[X2[Int, Int], Int, Boolean, Vector[Food]] _))
check(forAll(prop[String, Food, X3[Food, Country, Boolean], Int] _))
check(forAll(prop[String, Food, X3U[Food, Country, Boolean], Int] _))
- check(forAll(prop[
- Option[Vector[Food]],
- Vector[Vector[X2[Vector[(Person, X1[Char])], Country]]],
- X3[Food, Country, String],
- Vector[(Food, Country)]] _))
+ check(
+ forAll(
+ prop[Option[Vector[Food]], Vector[
+ Vector[X2[Vector[(Person, X1[Char])], Country]]
+ ], X3[Food, Country, String], Vector[(Food, Country)]] _
+ )
+ )
}
test("array fields") {
def prop[T: Arbitrary: TypedEncoder: ClassTag] = forAll {
- (d1: Array[T], d2: Array[Option[T]], d3: Array[X1[T]], d4: Array[X1[Option[T]]],
- d5: X1[Array[T]]) =>
+ (d1: Array[T],
+ d2: Array[Option[T]],
+ d3: Array[X1[T]],
+ d4: Array[X1[Option[T]]],
+ d5: X1[Array[T]]
+ ) =>
TypedDataset.create(Seq(d1)).collect().run().head.sameElements(d1) &&
TypedDataset.create(Seq(d2)).collect().run().head.sameElements(d2) &&
TypedDataset.create(Seq(d3)).collect().run().head.sameElements(d3) &&
@@ -55,13 +66,17 @@ class CreateTests extends TypedDatasetSuite with Matchers {
test("vector fields") {
def prop[T: Arbitrary: TypedEncoder] = forAll {
- (d1: Vector[T], d2: Vector[Option[T]], d3: Vector[X1[T]], d4: Vector[X1[Option[T]]],
- d5: X1[Vector[T]]) =>
- (TypedDataset.create(Seq(d1)).collect().run().head ?= d1) &&
- (TypedDataset.create(Seq(d2)).collect().run().head ?= d2) &&
- (TypedDataset.create(Seq(d3)).collect().run().head ?= d3) &&
- (TypedDataset.create(Seq(d4)).collect().run().head ?= d4) &&
- (TypedDataset.create(Seq(d5)).collect().run().head ?= d5)
+ (d1: Vector[T],
+ d2: Vector[Option[T]],
+ d3: Vector[X1[T]],
+ d4: Vector[X1[Option[T]]],
+ d5: X1[Vector[T]]
+ ) =>
+ (TypedDataset.create(Seq(d1)).collect().run().head ?= d1) &&
+ (TypedDataset.create(Seq(d2)).collect().run().head ?= d2) &&
+ (TypedDataset.create(Seq(d3)).collect().run().head ?= d3) &&
+ (TypedDataset.create(Seq(d4)).collect().run().head ?= d4) &&
+ (TypedDataset.create(Seq(d5)).collect().run().head ?= d5)
}
check(prop[Boolean])
@@ -77,9 +92,13 @@ class CreateTests extends TypedDatasetSuite with Matchers {
test("list fields") {
def prop[T: Arbitrary: TypedEncoder] = forAll {
- (d1: List[T], d2: List[Option[T]], d3: List[X1[T]], d4: List[X1[Option[T]]],
- d5: X1[List[T]]) =>
- (TypedDataset.create(Seq(d1)).collect().run().head ?= d1) &&
+ (d1: List[T],
+ d2: List[Option[T]],
+ d3: List[X1[T]],
+ d4: List[X1[Option[T]]],
+ d5: X1[List[T]]
+ ) =>
+ (TypedDataset.create(Seq(d1)).collect().run().head ?= d1) &&
(TypedDataset.create(Seq(d2)).collect().run().head ?= d2) &&
(TypedDataset.create(Seq(d3)).collect().run().head ?= d3) &&
(TypedDataset.create(Seq(d4)).collect().run().head ?= d4) &&
@@ -98,16 +117,23 @@ class CreateTests extends TypedDatasetSuite with Matchers {
}
test("Map fields (scala.Predef.Map / scala.collection.immutable.Map)") {
- def prop[A: Arbitrary: NotCatalystNullable: TypedEncoder, B: Arbitrary: NotCatalystNullable: TypedEncoder] = forAll {
- (d1: Map[A, B], d2: Map[B, A], d3: Map[A, Option[B]],
- d4: Map[A, X1[B]], d5: Map[X1[A], B], d6: Map[X1[A], X1[B]]) =>
-
- (TypedDataset.create(Seq(d1)).collect().run().head ?= d1) &&
- (TypedDataset.create(Seq(d2)).collect().run().head ?= d2) &&
- (TypedDataset.create(Seq(d3)).collect().run().head ?= d3) &&
- (TypedDataset.create(Seq(d4)).collect().run().head ?= d4) &&
- (TypedDataset.create(Seq(d5)).collect().run().head ?= d5) &&
- (TypedDataset.create(Seq(d6)).collect().run().head ?= d6)
+ def prop[
+ A: Arbitrary: NotCatalystNullable: TypedEncoder,
+ B: Arbitrary: NotCatalystNullable: TypedEncoder
+ ] = forAll {
+ (d1: Map[A, B],
+ d2: Map[B, A],
+ d3: Map[A, Option[B]],
+ d4: Map[A, X1[B]],
+ d5: Map[X1[A], B],
+ d6: Map[X1[A], X1[B]]
+ ) =>
+ (TypedDataset.create(Seq(d1)).collect().run().head ?= d1) &&
+ (TypedDataset.create(Seq(d2)).collect().run().head ?= d2) &&
+ (TypedDataset.create(Seq(d3)).collect().run().head ?= d3) &&
+ (TypedDataset.create(Seq(d4)).collect().run().head ?= d4) &&
+ (TypedDataset.create(Seq(d5)).collect().run().head ?= d5) &&
+ (TypedDataset.create(Seq(d6)).collect().run().head ?= d6)
}
check(prop[String, String])
@@ -123,14 +149,17 @@ class CreateTests extends TypedDatasetSuite with Matchers {
test("maps with Option keys should not resolve the TypedEncoder") {
val data: Seq[Map[Option[Int], Int]] = Seq(Map(Some(5) -> 5))
- illTyped("TypedDataset.create(data)", ".*could not find implicit value for parameter encoder.*")
+ illTyped(
+ "TypedDataset.create(data)",
+ ".*could not find implicit value for parameter encoder.*"
+ )
}
test("not aligned columns should throw an exception") {
- val v = Vector(X2(1,2))
+ val v = Vector(X2(1, 2))
val df = TypedDataset.create(v).dataset.toDF()
- a [IllegalStateException] should be thrownBy {
+ a[IllegalStateException] should be thrownBy {
TypedDataset.createUnsafe[X1[Int]](df).show().run()
}
}
@@ -139,13 +168,18 @@ class CreateTests extends TypedDatasetSuite with Matchers {
// e.g. when loading data from partitioned dataset
// the partition columns get appended to the end of the underlying relation
def prop[A: Arbitrary: TypedEncoder, B: Arbitrary: TypedEncoder] = forAll {
- (a1: A, b1: B) => {
- val ds = TypedDataset.create(
- Vector((b1, a1))
- ).dataset.toDF("b", "a").as[X2[A, B]](TypedExpressionEncoder[X2[A, B]])
- TypedDataset.create(ds).collect().run().head ?= X2(a1, b1)
-
- }
+ (a1: A, b1: B) =>
+ {
+ val ds = TypedDataset
+ .create(
+ Vector((b1, a1))
+ )
+ .dataset
+ .toDF("b", "a")
+ .as[X2[A, B]](TypedExpressionEncoder[X2[A, B]])
+ TypedDataset.create(ds).collect().run().head ?= X2(a1, b1)
+
+ }
}
check(prop[X1[Double], X1[X1[SQLDate]]])
check(prop[String, Int])
diff --git a/dataset/src/test/scala/frameless/DropTest.scala b/dataset/src/test/scala/frameless/DropTest.scala
index 3e5a0d739..affd0d8d3 100644
--- a/dataset/src/test/scala/frameless/DropTest.scala
+++ b/dataset/src/test/scala/frameless/DropTest.scala
@@ -8,28 +8,36 @@ class DropTest extends TypedDatasetSuite {
import DropTest._
test("fail to compile on missing value") {
- val f: TypedDataset[X] = TypedDataset.create(X(1, 1, false) :: X(1, 1, false) :: X(1, 10, false) :: Nil)
+ val f: TypedDataset[X] = TypedDataset.create(
+ X(1, 1, false) :: X(1, 1, false) :: X(1, 10, false) :: Nil
+ )
illTyped {
"""val fNew: TypedDataset[XMissing] = f.drop[XMissing]('j)"""
}
}
test("fail to compile on different column name") {
- val f: TypedDataset[X] = TypedDataset.create(X(1, 1, false) :: X(1, 1, false) :: X(1, 10, false) :: Nil)
+ val f: TypedDataset[X] = TypedDataset.create(
+ X(1, 1, false) :: X(1, 1, false) :: X(1, 10, false) :: Nil
+ )
illTyped {
"""val fNew: TypedDataset[XDifferentColumnName] = f.drop[XDifferentColumnName]('j)"""
}
}
test("fail to compile on added column name") {
- val f: TypedDataset[X] = TypedDataset.create(X(1, 1, false) :: X(1, 1, false) :: X(1, 10, false) :: Nil)
+ val f: TypedDataset[X] = TypedDataset.create(
+ X(1, 1, false) :: X(1, 1, false) :: X(1, 10, false) :: Nil
+ )
illTyped {
"""val fNew: TypedDataset[XAdded] = f.drop[XAdded]('j)"""
}
}
test("remove column in the middle") {
- val f: TypedDataset[X] = TypedDataset.create(X(1, 1, false) :: X(1, 1, false) :: X(1, 10, false) :: Nil)
+ val f: TypedDataset[X] = TypedDataset.create(
+ X(1, 1, false) :: X(1, 1, false) :: X(1, 10, false) :: Nil
+ )
val fNew: TypedDataset[XGood] = f.drop[XGood]
fNew.collect().run().foreach(xg => assert(xg === XGood(1, false)))
diff --git a/dataset/src/test/scala/frameless/DropTupledTest.scala b/dataset/src/test/scala/frameless/DropTupledTest.scala
index ff0158b91..d23b8a640 100644
--- a/dataset/src/test/scala/frameless/DropTupledTest.scala
+++ b/dataset/src/test/scala/frameless/DropTupledTest.scala
@@ -7,9 +7,9 @@ class DropTupledTest extends TypedDatasetSuite {
test("drop five columns") {
def prop[A: TypedEncoder](value: A): Prop = {
val d5 = TypedDataset.create(X5(value, value, value, value, value) :: Nil)
- val d4 = d5.dropTupled('a) //drops first column
- val d3 = d4.dropTupled('_4) //drops last column
- val d2 = d3.dropTupled('_2) //drops middle column
+ val d4 = d5.dropTupled('a) // drops first column
+ val d3 = d4.dropTupled('_4) // drops last column
+ val d2 = d3.dropTupled('_2) // drops middle column
val d1 = d2.dropTupled('_2)
Tuple1(value) ?= d1.collect().run().head
diff --git a/dataset/src/test/scala/frameless/ExplodeTests.scala b/dataset/src/test/scala/frameless/ExplodeTests.scala
index 3078ceb12..090f95522 100644
--- a/dataset/src/test/scala/frameless/ExplodeTests.scala
+++ b/dataset/src/test/scala/frameless/ExplodeTests.scala
@@ -1,7 +1,7 @@
package frameless
import frameless.functions.CatalystExplodableCollection
-import org.scalacheck.{Arbitrary, Prop}
+import org.scalacheck.{ Arbitrary, Prop }
import org.scalacheck.Prop.forAll
import org.scalacheck.Prop._
@@ -9,12 +9,19 @@ import scala.reflect.ClassTag
class ExplodeTests extends TypedDatasetSuite {
test("simple explode test") {
- val ds = TypedDataset.create(Seq((1,Array(1,2))))
- ds.explode('_2): TypedDataset[(Int,Int)]
+ val ds = TypedDataset.create(Seq((1, Array(1, 2))))
+ ds.explode('_2): TypedDataset[(Int, Int)]
}
test("explode on vectors/list/seq") {
- def prop[F[X] <: Traversable[X] : CatalystExplodableCollection, A: TypedEncoder](xs: List[X1[F[A]]])(implicit arb: Arbitrary[F[A]], enc: TypedEncoder[F[A]]): Prop = {
+ def prop[
+ F[X] <: Traversable[X]: CatalystExplodableCollection,
+ A: TypedEncoder
+ ](xs: List[X1[F[A]]]
+ )(implicit
+ arb: Arbitrary[F[A]],
+ enc: TypedEncoder[F[A]]
+ ): Prop = {
val tds = TypedDataset.create(xs)
val framelessResults = tds.explode('a).collect().run().toVector
@@ -49,11 +56,14 @@ class ExplodeTests extends TypedDatasetSuite {
}
test("explode on maps") {
- def prop[A: TypedEncoder: ClassTag, B: TypedEncoder: ClassTag](xs: List[X1[Map[A, B]]]): Prop = {
+ def prop[A: TypedEncoder: ClassTag, B: TypedEncoder: ClassTag](
+ xs: List[X1[Map[A, B]]]
+ ): Prop = {
val tds = TypedDataset.create(xs)
val framelessResults = tds.explodeMap('a).collect().run().toVector
- val scalaResults = xs.flatMap(_.a.toList).map(t => Tuple1(Tuple2(t._1, t._2))).toVector
+ val scalaResults =
+ xs.flatMap(_.a.toList).map(t => Tuple1(Tuple2(t._1, t._2))).toVector
framelessResults ?= scalaResults
}
@@ -64,11 +74,18 @@ class ExplodeTests extends TypedDatasetSuite {
}
test("explode on maps preserving other columns") {
- def prop[K: TypedEncoder: ClassTag, A: TypedEncoder: ClassTag, B: TypedEncoder: ClassTag](xs: List[X2[K, Map[A, B]]]): Prop = {
+ def prop[
+ K: TypedEncoder: ClassTag,
+ A: TypedEncoder: ClassTag,
+ B: TypedEncoder: ClassTag
+ ](xs: List[X2[K, Map[A, B]]]
+ ): Prop = {
val tds = TypedDataset.create(xs)
val framelessResults = tds.explodeMap('b).collect().run().toVector
- val scalaResults = xs.flatMap { x2 => x2.b.toList.map((x2.a, _)) }.toVector
+ val scalaResults = xs.flatMap { x2 =>
+ x2.b.toList.map((x2.a, _))
+ }.toVector
framelessResults ?= scalaResults
}
@@ -79,11 +96,19 @@ class ExplodeTests extends TypedDatasetSuite {
}
test("explode on maps making sure no key / value naming collision happens") {
- def prop[K: TypedEncoder: ClassTag, V: TypedEncoder: ClassTag, A: TypedEncoder: ClassTag, B: TypedEncoder: ClassTag](xs: List[X3KV[K, V, Map[A, B]]]): Prop = {
+ def prop[
+ K: TypedEncoder: ClassTag,
+ V: TypedEncoder: ClassTag,
+ A: TypedEncoder: ClassTag,
+ B: TypedEncoder: ClassTag
+ ](xs: List[X3KV[K, V, Map[A, B]]]
+ ): Prop = {
val tds = TypedDataset.create(xs)
val framelessResults = tds.explodeMap('c).collect().run().toVector
- val scalaResults = xs.flatMap { x3 => x3.c.toList.map((x3.key, x3.value, _)) }.toVector
+ val scalaResults = xs.flatMap { x3 =>
+ x3.c.toList.map((x3.key, x3.value, _))
+ }.toVector
framelessResults ?= scalaResults
}
diff --git a/dataset/src/test/scala/frameless/FilterTests.scala b/dataset/src/test/scala/frameless/FilterTests.scala
index 56d5d2ec5..c93660ca9 100644
--- a/dataset/src/test/scala/frameless/FilterTests.scala
+++ b/dataset/src/test/scala/frameless/FilterTests.scala
@@ -7,7 +7,12 @@ import org.scalacheck.Prop._
final class FilterTests extends TypedDatasetSuite with Matchers {
test("filter('a == lit(b))") {
- def prop[A: TypedEncoder](elem: A, data: Vector[X1[A]])(implicit ex1: TypedEncoder[X1[A]]): Prop = {
+ def prop[A: TypedEncoder](
+ elem: A,
+ data: Vector[X1[A]]
+ )(implicit
+ ex1: TypedEncoder[X1[A]]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col('a)
@@ -22,7 +27,12 @@ final class FilterTests extends TypedDatasetSuite with Matchers {
}
test("filter('a =!= lit(b))") {
- def prop[A: TypedEncoder](elem: A, data: Vector[X1[A]])(implicit ex1: TypedEncoder[X1[A]]): Prop = {
+ def prop[A: TypedEncoder](
+ elem: A,
+ data: Vector[X1[A]]
+ )(implicit
+ ex1: TypedEncoder[X1[A]]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col('a)
@@ -61,13 +71,13 @@ final class FilterTests extends TypedDatasetSuite with Matchers {
}
test("filter('a =!= 'b") {
- def prop[A: TypedEncoder](elem: A, data: Vector[X2[A,A]]): Prop = {
+ def prop[A: TypedEncoder](elem: A, data: Vector[X2[A, A]]): Prop = {
val dataset = TypedDataset.create(data)
val cA = dataset.col('a)
val cB = dataset.col('b)
val dataset2 = dataset.filter(cA =!= cB).collect().run().toVector
- val data2 = data.filter(x => x.a != x.b )
+ val data2 = data.filter(x => x.a != x.b)
(dataset2 ?= data2).&&(dataset.filter(cA =!= cA).count().run() ?= 0)
}
@@ -82,7 +92,8 @@ final class FilterTests extends TypedDatasetSuite with Matchers {
test("filter with arithmetic expressions: addition") {
check(forAll { (data: Vector[X1[Int]]) =>
val ds = TypedDataset.create(data)
- val res = ds.filter((ds('a) + 1) === (ds('a) + 1)).collect().run().toVector
+ val res =
+ ds.filter((ds('a) + 1) === (ds('a) + 1)).collect().run().toVector
res ?= data
})
}
@@ -99,21 +110,41 @@ final class FilterTests extends TypedDatasetSuite with Matchers {
val t = X1(1) :: X1(2) :: X1(3) :: Nil
val tds: TypedDataset[X1[Int]] = TypedDataset.create(t)
- assert(tds.filter(tds('a) * 2 === 2).collect().run().toVector === Vector(X1(1)))
- assert(tds.filter(tds('a) * 3 === 3).collect().run().toVector === Vector(X1(1)))
+ assert(
+ tds.filter(tds('a) * 2 === 2).collect().run().toVector === Vector(X1(1))
+ )
+ assert(
+ tds.filter(tds('a) * 3 === 3).collect().run().toVector === Vector(X1(1))
+ )
}
test("Option equality/inequality for columns") {
- def prop[A <: Option[_] : TypedEncoder](a: A, b: A): Prop = {
+ def prop[A <: Option[_]: TypedEncoder](a: A, b: A): Prop = {
val data = X2(a, b) :: X2(a, a) :: Nil
val dataset = TypedDataset.create(data)
val A = dataset.col('a)
val B = dataset.col('b)
- (data.filter(x => x.a == x.b).toSet ?= dataset.filter(A === B).collect().run().toSet).
- &&(data.filter(x => x.a != x.b).toSet ?= dataset.filter(A =!= B).collect().run().toSet).
- &&(data.filter(x => x.a == None).toSet ?= dataset.filter(A.isNone).collect().run().toSet).
- &&(data.filter(x => x.a == None).toSet ?= dataset.filter(A.isNotNone === false).collect().run().toSet)
+ (data
+ .filter(x => x.a == x.b)
+ .toSet ?= dataset.filter(A === B).collect().run().toSet)
+ .&&(
+ data
+ .filter(x => x.a != x.b)
+ .toSet ?= dataset.filter(A =!= B).collect().run().toSet
+ )
+ .&&(
+ data
+ .filter(x => x.a == None)
+ .toSet ?= dataset.filter(A.isNone).collect().run().toSet
+ )
+ .&&(
+ data.filter(x => x.a == None).toSet ?= dataset
+ .filter(A.isNotNone === false)
+ .collect()
+ .run()
+ .toSet
+ )
}
check(forAll(prop[Option[Int]] _))
@@ -126,15 +157,31 @@ final class FilterTests extends TypedDatasetSuite with Matchers {
}
test("Option equality/inequality for lit") {
- def prop[A <: Option[_] : TypedEncoder](a: A, b: A, cLit: A): Prop = {
+ def prop[A <: Option[_]: TypedEncoder](a: A, b: A, cLit: A): Prop = {
val data = X2(a, b) :: X2(a, cLit) :: Nil
val dataset = TypedDataset.create(data)
val colA = dataset.col('a)
- (data.filter(x => x.a == cLit).toSet ?= dataset.filter(colA === cLit).collect().run().toSet).
- &&(data.filter(x => x.a != cLit).toSet ?= dataset.filter(colA =!= cLit).collect().run().toSet).
- &&(data.filter(x => x.a == None).toSet ?= dataset.filter(colA.isNone).collect().run().toSet).
- &&(data.filter(x => x.a == None).toSet ?= dataset.filter(colA.isNotNone === false).collect().run().toSet)
+ (data
+ .filter(x => x.a == cLit)
+ .toSet ?= dataset.filter(colA === cLit).collect().run().toSet)
+ .&&(
+ data
+ .filter(x => x.a != cLit)
+ .toSet ?= dataset.filter(colA =!= cLit).collect().run().toSet
+ )
+ .&&(
+ data
+ .filter(x => x.a == None)
+ .toSet ?= dataset.filter(colA.isNone).collect().run().toSet
+ )
+ .&&(
+ data.filter(x => x.a == None).toSet ?= dataset
+ .filter(colA.isNotNone === false)
+ .collect()
+ .run()
+ .toSet
+ )
}
check(forAll(prop[Option[Int]] _))
@@ -148,7 +195,10 @@ final class FilterTests extends TypedDatasetSuite with Matchers {
}
test("Option content filter") {
- val data = (Option(1L), Option(2L)) :: (Option(0L), Option(1L)) :: (None, None) :: Nil
+ val data = (Option(1L), Option(2L)) :: (Option(0L), Option(1L)) :: (
+ None,
+ None
+ ) :: Nil
val ds = TypedDataset.create(data)
@@ -162,13 +212,20 @@ final class FilterTests extends TypedDatasetSuite with Matchers {
ds.filter(exists).collect().run() shouldEqual Seq(Option(0L) -> Option(1L))
ds.filter(forall).collect().run() shouldEqual Seq(
- Option(0L) -> Option(1L), (None -> None))
+ Option(0L) -> Option(1L),
+ (None -> None)
+ )
}
test("filter with isin values") {
- def prop[A: TypedEncoder](data: Vector[X1[A]], values: Vector[A])(implicit a : CatalystIsin[A]): Prop = {
+ def prop[A: TypedEncoder](
+ data: Vector[X1[A]],
+ values: Vector[A]
+ )(implicit
+ a: CatalystIsin[A]
+ ): Prop = {
val ds = TypedDataset.create(data)
- val res = ds.filter(ds('a).isin(values:_*)).collect().run().toVector
+ val res = ds.filter(ds('a).isin(values: _*)).collect().run().toVector
res ?= data.filter(d => values.contains(d.a))
}
diff --git a/dataset/src/test/scala/frameless/FlattenTests.scala b/dataset/src/test/scala/frameless/FlattenTests.scala
index a65e51b8f..915eb67de 100644
--- a/dataset/src/test/scala/frameless/FlattenTests.scala
+++ b/dataset/src/test/scala/frameless/FlattenTests.scala
@@ -4,18 +4,19 @@ import org.scalacheck.Prop
import org.scalacheck.Prop.forAll
import org.scalacheck.Prop._
-
class FlattenTests extends TypedDatasetSuite {
test("simple flatten test") {
- val ds: TypedDataset[(Int,Option[Int])] = TypedDataset.create(Seq((1,Option(1))))
- ds.flattenOption('_2): TypedDataset[(Int,Int)]
+ val ds: TypedDataset[(Int, Option[Int])] =
+ TypedDataset.create(Seq((1, Option(1))))
+ ds.flattenOption('_2): TypedDataset[(Int, Int)]
}
test("different Optional types") {
def prop[A: TypedEncoder](xs: List[X1[Option[A]]]): Prop = {
val tds: TypedDataset[X1[Option[A]]] = TypedDataset.create(xs)
- val framelessResults: Seq[Tuple1[A]] = tds.flattenOption('a).collect().run().toVector
+ val framelessResults: Seq[Tuple1[A]] =
+ tds.flattenOption('a).collect().run().toVector
val scalaResults = xs.flatMap(_.a).map(Tuple1(_)).toVector
framelessResults ?= scalaResults
diff --git a/dataset/src/test/scala/frameless/GroupByTests.scala b/dataset/src/test/scala/frameless/GroupByTests.scala
index 7178def30..0bc96eb20 100644
--- a/dataset/src/test/scala/frameless/GroupByTests.scala
+++ b/dataset/src/test/scala/frameless/GroupByTests.scala
@@ -7,20 +7,25 @@ import org.scalacheck.Prop._
class GroupByTests extends TypedDatasetSuite {
test("groupByMany('a).agg(sum('b))") {
def prop[
- A: TypedEncoder : Ordering,
- B: TypedEncoder,
- Out: TypedEncoder : Numeric
- ](data: List[X2[A, B]])(
- implicit
- summable: CatalystSummable[B, Out],
- widen: B => Out
- ): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder,
+ Out: TypedEncoder: Numeric
+ ](data: List[X2[A, B]]
+ )(implicit
+ summable: CatalystSummable[B, Out],
+ widen: B => Out
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
- val datasetSumByA = dataset.groupByMany(A).agg(sum(B)).collect().run.toVector.sortBy(_._1)
- val sumByA = data.groupBy(_.a).map { case (k, v) => k -> v.map(_.b).map(widen).sum }.toVector.sortBy(_._1)
+ val datasetSumByA =
+ dataset.groupByMany(A).agg(sum(B)).collect().run.toVector.sortBy(_._1)
+ val sumByA = data
+ .groupBy(_.a)
+ .map { case (k, v) => k -> v.map(_.b).map(widen).sum }
+ .toVector
+ .sortBy(_._1)
datasetSumByA ?= sumByA
}
@@ -29,10 +34,11 @@ class GroupByTests extends TypedDatasetSuite {
}
test("agg(sum('a))") {
- def prop[A: TypedEncoder : Numeric](data: List[X1[A]])(
- implicit
- summable: CatalystSummable[A, A]
- ): Prop = {
+ def prop[A: TypedEncoder: Numeric](
+ data: List[X1[A]]
+ )(implicit
+ summable: CatalystSummable[A, A]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
@@ -46,14 +52,12 @@ class GroupByTests extends TypedDatasetSuite {
}
test("agg(sum('a), sum('b))") {
- def prop[
- A: TypedEncoder : Numeric,
- B: TypedEncoder : Numeric
- ](data: List[X2[A, B]])(
- implicit
- as: CatalystSummable[A, A],
- bs: CatalystSummable[B, B]
- ): Prop = {
+ def prop[A: TypedEncoder: Numeric, B: TypedEncoder: Numeric](
+ data: List[X2[A, B]]
+ )(implicit
+ as: CatalystSummable[A, A],
+ bs: CatalystSummable[B, B]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
@@ -70,21 +74,22 @@ class GroupByTests extends TypedDatasetSuite {
test("agg(sum('a), sum('b), sum('c))") {
def prop[
- A: TypedEncoder : Numeric,
- B: TypedEncoder : Numeric,
- C: TypedEncoder : Numeric
- ](data: List[X3[A, B, C]])(
- implicit
- as: CatalystSummable[A, A],
- bs: CatalystSummable[B, B],
- cs: CatalystSummable[C, C]
- ): Prop = {
+ A: TypedEncoder: Numeric,
+ B: TypedEncoder: Numeric,
+ C: TypedEncoder: Numeric
+ ](data: List[X3[A, B, C]]
+ )(implicit
+ as: CatalystSummable[A, A],
+ bs: CatalystSummable[B, B],
+ cs: CatalystSummable[C, C]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
val C = dataset.col[C]('c)
- val datasetSum = dataset.agg(sum(A), sum(B), sum(C)).collect().run().toVector
+ val datasetSum =
+ dataset.agg(sum(A), sum(B), sum(C)).collect().run().toVector
val listSumA = data.map(_.a).sum
val listSumB = data.map(_.b).sum
val listSumC = data.map(_.c).sum
@@ -97,30 +102,37 @@ class GroupByTests extends TypedDatasetSuite {
test("agg(sum('a), sum('b), min('c), max('d))") {
def prop[
- A: TypedEncoder : Numeric,
- B: TypedEncoder : Numeric,
- C: TypedEncoder : Numeric,
- D: TypedEncoder : Numeric
- ](data: List[X4[A, B, C, D]])(
- implicit
- as: CatalystSummable[A, A],
- bs: CatalystSummable[B, B],
- co: CatalystOrdered[C],
- fo: CatalystOrdered[D]
- ): Prop = {
+ A: TypedEncoder: Numeric,
+ B: TypedEncoder: Numeric,
+ C: TypedEncoder: Numeric,
+ D: TypedEncoder: Numeric
+ ](data: List[X4[A, B, C, D]]
+ )(implicit
+ as: CatalystSummable[A, A],
+ bs: CatalystSummable[B, B],
+ co: CatalystOrdered[C],
+ fo: CatalystOrdered[D]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
val C = dataset.col[C]('c)
val D = dataset.col[D]('d)
- val datasetSum = dataset.agg(sum(A), sum(B), min(C), max(D)).collect().run().toVector
+ val datasetSum =
+ dataset.agg(sum(A), sum(B), min(C), max(D)).collect().run().toVector
val listSumA = data.map(_.a).sum
val listSumB = data.map(_.b).sum
- val listMinC = if(data.isEmpty) implicitly[Numeric[C]].fromInt(0) else data.map(_.c).min
- val listMaxD = if(data.isEmpty) implicitly[Numeric[D]].fromInt(0) else data.map(_.d).max
-
- datasetSum ?= Vector(if (data.isEmpty) null else (listSumA, listSumB, listMinC, listMaxD))
+ val listMinC =
+ if (data.isEmpty) implicitly[Numeric[C]].fromInt(0)
+ else data.map(_.c).min
+ val listMaxD =
+ if (data.isEmpty) implicitly[Numeric[D]].fromInt(0)
+ else data.map(_.d).max
+
+ datasetSum ?= Vector(
+ if (data.isEmpty) null else (listSumA, listSumB, listMinC, listMaxD)
+ )
}
check(forAll(prop[Long, Long, Long, Int] _))
@@ -130,20 +142,25 @@ class GroupByTests extends TypedDatasetSuite {
test("groupBy('a).agg(sum('b))") {
def prop[
- A: TypedEncoder : Ordering,
- B: TypedEncoder,
- Out: TypedEncoder : Numeric
- ](data: List[X2[A, B]])(
- implicit
- summable: CatalystSummable[B, Out],
- widen: B => Out
- ): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder,
+ Out: TypedEncoder: Numeric
+ ](data: List[X2[A, B]]
+ )(implicit
+ summable: CatalystSummable[B, Out],
+ widen: B => Out
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
- val datasetSumByA = dataset.groupBy(A).agg(sum(B)).collect().run.toVector.sortBy(_._1)
- val sumByA = data.groupBy(_.a).mapValues(_.map(_.b).map(widen).sum).toVector.sortBy(_._1)
+ val datasetSumByA =
+ dataset.groupBy(A).agg(sum(B)).collect().run.toVector.sortBy(_._1)
+ val sumByA = data
+ .groupBy(_.a)
+ .mapValues(_.map(_.b).map(widen).sum)
+ .toVector
+ .sortBy(_._1)
datasetSumByA ?= sumByA
}
@@ -152,18 +169,23 @@ class GroupByTests extends TypedDatasetSuite {
}
test("groupBy('a).mapGroups('a, sum('b))") {
- def prop[
- A: TypedEncoder : Ordering,
- B: TypedEncoder : Numeric
- ](data: List[X2[A, B]]): Prop = {
+ def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Numeric](
+ data: List[X2[A, B]]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
- val datasetSumByA = dataset.groupBy(A)
- .deserialized.mapGroups { case (a, xs) => (a, xs.map(_.b).sum) }
- .collect().run().toVector.sortBy(_._1)
- val sumByA = data.groupBy(_.a).mapValues(_.map(_.b).sum).toVector.sortBy(_._1)
+ val datasetSumByA = dataset
+ .groupBy(A)
+ .deserialized
+ .mapGroups { case (a, xs) => (a, xs.map(_.b).sum) }
+ .collect()
+ .run()
+ .toVector
+ .sortBy(_._1)
+ val sumByA =
+ data.groupBy(_.a).mapValues(_.map(_.b).sum).toVector.sortBy(_._1)
datasetSumByA ?= sumByA
}
@@ -173,18 +195,18 @@ class GroupByTests extends TypedDatasetSuite {
test("groupBy('a).agg(sum('b), sum('c)) to groupBy('a).agg(sum('a), sum('b), sum('a), sum('b), sum('a))") {
def prop[
- A: TypedEncoder : Ordering,
- B: TypedEncoder,
- C: TypedEncoder,
- OutB: TypedEncoder : Numeric,
- OutC: TypedEncoder : Numeric
- ](data: List[X3[A, B, C]])(
- implicit
- summableB: CatalystSummable[B, OutB],
- summableC: CatalystSummable[C, OutC],
- widenb: B => OutB,
- widenc: C => OutC
- ): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder,
+ C: TypedEncoder,
+ OutB: TypedEncoder: Numeric,
+ OutC: TypedEncoder: Numeric
+ ](data: List[X3[A, B, C]]
+ )(implicit
+ summableB: CatalystSummable[B, OutB],
+ summableC: CatalystSummable[C, OutC],
+ widenb: B => OutB,
+ widenc: C => OutC
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
@@ -193,46 +215,85 @@ class GroupByTests extends TypedDatasetSuite {
val framelessSumBC = dataset
.groupBy(A)
.agg(sum(B), sum(C))
- .collect().run.toVector.sortBy(_._1)
-
- val scalaSumBC = data.groupBy(_.a).mapValues { xs =>
- (xs.map(_.b).map(widenb).sum, xs.map(_.c).map(widenc).sum)
- }.toVector.map {
- case (a, (b, c)) => (a, b, c)
- }.sortBy(_._1)
+ .collect()
+ .run
+ .toVector
+ .sortBy(_._1)
+
+ val scalaSumBC = data
+ .groupBy(_.a)
+ .mapValues { xs =>
+ (xs.map(_.b).map(widenb).sum, xs.map(_.c).map(widenc).sum)
+ }
+ .toVector
+ .map { case (a, (b, c)) => (a, b, c) }
+ .sortBy(_._1)
val framelessSumBCB = dataset
.groupBy(A)
.agg(sum(B), sum(C), sum(B))
- .collect().run.toVector.sortBy(_._1)
-
- val scalaSumBCB = data.groupBy(_.a).mapValues { xs =>
- (xs.map(_.b).map(widenb).sum, xs.map(_.c).map(widenc).sum, xs.map(_.b).map(widenb).sum)
- }.toVector.map {
- case (a, (b1, c, b2)) => (a, b1, c, b2)
- }.sortBy(_._1)
+ .collect()
+ .run
+ .toVector
+ .sortBy(_._1)
+
+ val scalaSumBCB = data
+ .groupBy(_.a)
+ .mapValues { xs =>
+ (
+ xs.map(_.b).map(widenb).sum,
+ xs.map(_.c).map(widenc).sum,
+ xs.map(_.b).map(widenb).sum
+ )
+ }
+ .toVector
+ .map { case (a, (b1, c, b2)) => (a, b1, c, b2) }
+ .sortBy(_._1)
val framelessSumBCBC = dataset
.groupBy(A)
.agg(sum(B), sum(C), sum(B), sum(C))
- .collect().run.toVector.sortBy(_._1)
-
- val scalaSumBCBC = data.groupBy(_.a).mapValues { xs =>
- (xs.map(_.b).map(widenb).sum, xs.map(_.c).map(widenc).sum, xs.map(_.b).map(widenb).sum, xs.map(_.c).map(widenc).sum)
- }.toVector.map {
- case (a, (b1, c1, b2, c2)) => (a, b1, c1, b2, c2)
- }.sortBy(_._1)
+ .collect()
+ .run
+ .toVector
+ .sortBy(_._1)
+
+ val scalaSumBCBC = data
+ .groupBy(_.a)
+ .mapValues { xs =>
+ (
+ xs.map(_.b).map(widenb).sum,
+ xs.map(_.c).map(widenc).sum,
+ xs.map(_.b).map(widenb).sum,
+ xs.map(_.c).map(widenc).sum
+ )
+ }
+ .toVector
+ .map { case (a, (b1, c1, b2, c2)) => (a, b1, c1, b2, c2) }
+ .sortBy(_._1)
val framelessSumBCBCB = dataset
.groupBy(A)
.agg(sum(B), sum(C), sum(B), sum(C), sum(B))
- .collect().run.toVector.sortBy(_._1)
-
- val scalaSumBCBCB = data.groupBy(_.a).mapValues { xs =>
- (xs.map(_.b).map(widenb).sum, xs.map(_.c).map(widenc).sum, xs.map(_.b).map(widenb).sum, xs.map(_.c).map(widenc).sum, xs.map(_.b).map(widenb).sum)
- }.toVector.map {
- case (a, (b1, c1, b2, c2, b3)) => (a, b1, c1, b2, c2, b3)
- }.sortBy(_._1)
+ .collect()
+ .run
+ .toVector
+ .sortBy(_._1)
+
+ val scalaSumBCBCB = data
+ .groupBy(_.a)
+ .mapValues { xs =>
+ (
+ xs.map(_.b).map(widenb).sum,
+ xs.map(_.c).map(widenc).sum,
+ xs.map(_.b).map(widenb).sum,
+ xs.map(_.c).map(widenc).sum,
+ xs.map(_.b).map(widenb).sum
+ )
+ }
+ .toVector
+ .map { case (a, (b1, c1, b2, c2, b3)) => (a, b1, c1, b2, c2, b3) }
+ .sortBy(_._1)
(framelessSumBC ?= scalaSumBC)
.&&(framelessSumBCB ?= scalaSumBCB)
@@ -245,70 +306,110 @@ class GroupByTests extends TypedDatasetSuite {
test("groupBy('a, 'b).agg(sum('c)) to groupBy('a, 'b).agg(sum('c),sum('c),sum('c),sum('c),sum('c))") {
def prop[
- A: TypedEncoder : Ordering,
- B: TypedEncoder : Ordering,
- C: TypedEncoder,
- OutC: TypedEncoder: Numeric
- ](data: List[X3[A, B, C]])(
- implicit
- summableC: CatalystSummable[C, OutC],
- widenc: C => OutC
- ): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder: Ordering,
+ C: TypedEncoder,
+ OutC: TypedEncoder: Numeric
+ ](data: List[X3[A, B, C]]
+ )(implicit
+ summableC: CatalystSummable[C, OutC],
+ widenc: C => OutC
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
val C = dataset.col[C]('c)
val framelessSumC = dataset
- .groupBy(A,B)
+ .groupBy(A, B)
.agg(sum(C))
- .collect().run.toVector.sortBy(x => (x._1,x._2))
-
- val scalaSumC = data.groupBy(x => (x.a,x.b)).mapValues { xs =>
- xs.map(_.c).map(widenc).sum
- }.toVector.map { case ((a, b), c) => (a, b, c) }.sortBy(x => (x._1,x._2))
+ .collect()
+ .run
+ .toVector
+ .sortBy(x => (x._1, x._2))
+
+ val scalaSumC = data
+ .groupBy(x => (x.a, x.b))
+ .mapValues { xs => xs.map(_.c).map(widenc).sum }
+ .toVector
+ .map { case ((a, b), c) => (a, b, c) }
+ .sortBy(x => (x._1, x._2))
val framelessSumCC = dataset
- .groupBy(A,B)
+ .groupBy(A, B)
.agg(sum(C), sum(C))
- .collect().run.toVector.sortBy(x => (x._1,x._2))
-
- val scalaSumCC = data.groupBy(x => (x.a,x.b)).mapValues { xs =>
- val s = xs.map(_.c).map(widenc).sum; (s,s)
- }.toVector.map { case ((a, b), (c1, c2)) => (a, b, c1, c2) }.sortBy(x => (x._1,x._2))
+ .collect()
+ .run
+ .toVector
+ .sortBy(x => (x._1, x._2))
+
+ val scalaSumCC = data
+ .groupBy(x => (x.a, x.b))
+ .mapValues { xs =>
+ val s = xs.map(_.c).map(widenc).sum; (s, s)
+ }
+ .toVector
+ .map { case ((a, b), (c1, c2)) => (a, b, c1, c2) }
+ .sortBy(x => (x._1, x._2))
val framelessSumCCC = dataset
- .groupBy(A,B)
+ .groupBy(A, B)
.agg(sum(C), sum(C), sum(C))
- .collect().run.toVector.sortBy(x => (x._1,x._2))
-
- val scalaSumCCC = data.groupBy(x => (x.a,x.b)).mapValues { xs =>
- val s = xs.map(_.c).map(widenc).sum; (s,s,s)
- }.toVector.map { case ((a, b), (c1, c2, c3)) => (a, b, c1, c2, c3) }.sortBy(x => (x._1,x._2))
+ .collect()
+ .run
+ .toVector
+ .sortBy(x => (x._1, x._2))
+
+ val scalaSumCCC = data
+ .groupBy(x => (x.a, x.b))
+ .mapValues { xs =>
+ val s = xs.map(_.c).map(widenc).sum; (s, s, s)
+ }
+ .toVector
+ .map { case ((a, b), (c1, c2, c3)) => (a, b, c1, c2, c3) }
+ .sortBy(x => (x._1, x._2))
val framelessSumCCCC = dataset
- .groupBy(A,B)
+ .groupBy(A, B)
.agg(sum(C), sum(C), sum(C), sum(C))
- .collect().run.toVector.sortBy(x => (x._1,x._2))
-
- val scalaSumCCCC = data.groupBy(x => (x.a,x.b)).mapValues { xs =>
- val s = xs.map(_.c).map(widenc).sum; (s,s,s,s)
- }.toVector.map { case ((a, b), (c1, c2, c3, c4)) => (a, b, c1, c2, c3, c4) }.sortBy(x => (x._1,x._2))
+ .collect()
+ .run
+ .toVector
+ .sortBy(x => (x._1, x._2))
+
+ val scalaSumCCCC = data
+ .groupBy(x => (x.a, x.b))
+ .mapValues { xs =>
+ val s = xs.map(_.c).map(widenc).sum; (s, s, s, s)
+ }
+ .toVector
+ .map { case ((a, b), (c1, c2, c3, c4)) => (a, b, c1, c2, c3, c4) }
+ .sortBy(x => (x._1, x._2))
val framelessSumCCCCC = dataset
- .groupBy(A,B)
+ .groupBy(A, B)
.agg(sum(C), sum(C), sum(C), sum(C), sum(C))
- .collect().run.toVector.sortBy(x => (x._1,x._2))
-
- val scalaSumCCCCC = data.groupBy(x => (x.a,x.b)).mapValues { xs =>
- val s = xs.map(_.c).map(widenc).sum; (s,s,s,s,s)
- }.toVector.map { case ((a, b), (c1, c2, c3, c4, c5)) => (a, b, c1, c2, c3, c4, c5) }.sortBy(x => (x._1,x._2))
+ .collect()
+ .run
+ .toVector
+ .sortBy(x => (x._1, x._2))
+
+ val scalaSumCCCCC = data
+ .groupBy(x => (x.a, x.b))
+ .mapValues { xs =>
+ val s = xs.map(_.c).map(widenc).sum; (s, s, s, s, s)
+ }
+ .toVector
+ .map {
+ case ((a, b), (c1, c2, c3, c4, c5)) => (a, b, c1, c2, c3, c4, c5)
+ }
+ .sortBy(x => (x._1, x._2))
(framelessSumC ?= scalaSumC) &&
- (framelessSumCC ?= scalaSumCC) &&
- (framelessSumCCC ?= scalaSumCCC) &&
- (framelessSumCCCC ?= scalaSumCCCC) &&
- (framelessSumCCCCC ?= scalaSumCCCCC)
+ (framelessSumCC ?= scalaSumCC) &&
+ (framelessSumCCC ?= scalaSumCCC) &&
+ (framelessSumCCCC ?= scalaSumCCCC) &&
+ (framelessSumCCCCC ?= scalaSumCCCCC)
}
check(forAll(prop[String, Long, BigDecimal, BigDecimal] _))
@@ -316,19 +417,19 @@ class GroupByTests extends TypedDatasetSuite {
test("groupBy('a, 'b).agg(sum('c), sum('d))") {
def prop[
- A: TypedEncoder : Ordering,
- B: TypedEncoder : Ordering,
- C: TypedEncoder,
- D: TypedEncoder,
- OutC: TypedEncoder : Numeric,
- OutD: TypedEncoder : Numeric
- ](data: List[X4[A, B, C, D]])(
- implicit
- summableC: CatalystSummable[C, OutC],
- summableD: CatalystSummable[D, OutD],
- widenc: C => OutC,
- widend: D => OutD
- ): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder: Ordering,
+ C: TypedEncoder,
+ D: TypedEncoder,
+ OutC: TypedEncoder: Numeric,
+ OutD: TypedEncoder: Numeric
+ ](data: List[X4[A, B, C, D]]
+ )(implicit
+ summableC: CatalystSummable[C, OutC],
+ summableD: CatalystSummable[D, OutD],
+ widenc: C => OutC,
+ widend: D => OutD
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
@@ -338,13 +439,19 @@ class GroupByTests extends TypedDatasetSuite {
val datasetSumByAB = dataset
.groupBy(A, B)
.agg(sum(C), sum(D))
- .collect().run.toVector.sortBy(x => (x._1, x._2))
-
- val sumByAB = data.groupBy(x => (x.a, x.b)).mapValues { xs =>
- (xs.map(_.c).map(widenc).sum, xs.map(_.d).map(widend).sum)
- }.toVector.map {
- case ((a, b), (c, d)) => (a, b, c, d)
- }.sortBy(x => (x._1, x._2))
+ .collect()
+ .run
+ .toVector
+ .sortBy(x => (x._1, x._2))
+
+ val sumByAB = data
+ .groupBy(x => (x.a, x.b))
+ .mapValues { xs =>
+ (xs.map(_.c).map(widenc).sum, xs.map(_.d).map(widend).sum)
+ }
+ .toVector
+ .map { case ((a, b), (c, d)) => (a, b, c, d) }
+ .sortBy(x => (x._1, x._2))
datasetSumByAB ?= sumByAB
}
@@ -354,10 +461,11 @@ class GroupByTests extends TypedDatasetSuite {
test("groupBy('a, 'b).mapGroups('a, 'b, sum('c))") {
def prop[
- A: TypedEncoder : Ordering,
- B: TypedEncoder : Ordering,
- C: TypedEncoder : Numeric
- ](data: List[X3[A, B, C]]): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder: Ordering,
+ C: TypedEncoder: Numeric
+ ](data: List[X3[A, B, C]]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
@@ -365,12 +473,19 @@ class GroupByTests extends TypedDatasetSuite {
val datasetSumByAB = dataset
.groupBy(A, B)
- .deserialized.mapGroups { case ((a, b), xs) => (a, b, xs.map(_.c).sum) }
- .collect().run().toVector.sortBy(x => (x._1, x._2))
-
- val sumByAB = data.groupBy(x => (x.a, x.b))
+ .deserialized
+ .mapGroups { case ((a, b), xs) => (a, b, xs.map(_.c).sum) }
+ .collect()
+ .run()
+ .toVector
+ .sortBy(x => (x._1, x._2))
+
+ val sumByAB = data
+ .groupBy(x => (x.a, x.b))
.mapValues { xs => xs.map(_.c).sum }
- .toVector.map { case ((a, b), c) => (a, b, c) }.sortBy(x => (x._1, x._2))
+ .toVector
+ .map { case ((a, b), c) => (a, b, c) }
+ .sortBy(x => (x._1, x._2))
datasetSumByAB ?= sumByAB
}
@@ -379,17 +494,19 @@ class GroupByTests extends TypedDatasetSuite {
}
test("groupBy('a).mapGroups(('a, toVector(('a, 'b))") {
- def prop[
- A: TypedEncoder: Ordering,
- B: TypedEncoder: Ordering
- ](data: Vector[X2[A, B]]): Prop = {
+ def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Ordering](
+ data: Vector[X2[A, B]]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val datasetGrouped = dataset
.groupBy(A)
- .deserialized.mapGroups((a, xs) => (a, xs.toVector.sorted))
- .collect().run.toMap
+ .deserialized
+ .mapGroups((a, xs) => (a, xs.toVector.sorted))
+ .collect()
+ .run
+ .toMap
val dataGrouped = data.groupBy(_.a).map { case (k, v) => k -> v.sorted }
@@ -402,21 +519,23 @@ class GroupByTests extends TypedDatasetSuite {
}
test("groupBy('a).flatMapGroups(('a, toVector(('a, 'b))") {
- def prop[
- A: TypedEncoder : Ordering,
- B: TypedEncoder : Ordering
- ](data: Vector[X2[A, B]]): Prop = {
+ def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Ordering](
+ data: Vector[X2[A, B]]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val datasetGrouped = dataset
.groupBy(A)
- .deserialized.flatMapGroups((a, xs) => xs.map(x => (a, x)))
- .collect().run
+ .deserialized
+ .flatMapGroups((a, xs) => xs.map(x => (a, x)))
+ .collect()
+ .run
.sorted
val dataGrouped = data
- .groupBy(_.a).toSeq
+ .groupBy(_.a)
+ .toSeq
.flatMap { case (a, xs) => xs.map(x => (a, x)) }
.sorted
@@ -430,22 +549,26 @@ class GroupByTests extends TypedDatasetSuite {
test("groupBy('a, 'b).flatMapGroups((('a,'b) toVector((('a,'b), 'c))") {
def prop[
- A: TypedEncoder : Ordering,
- B: TypedEncoder : Ordering,
- C: TypedEncoder : Ordering
- ](data: Vector[X3[A, B, C]]): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder: Ordering,
+ C: TypedEncoder: Ordering
+ ](data: Vector[X3[A, B, C]]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val cA = dataset.col[A]('a)
val cB = dataset.col[B]('b)
val datasetGrouped = dataset
.groupBy(cA, cB)
- .deserialized.flatMapGroups((a, xs) => xs.map(x => (a, x)))
- .collect().run()
+ .deserialized
+ .flatMapGroups((a, xs) => xs.map(x => (a, x)))
+ .collect()
+ .run()
.sorted
val dataGrouped = data
- .groupBy(t => (t.a,t.b)).toSeq
+ .groupBy(t => (t.a, t.b))
+ .toSeq
.flatMap { case (a, xs) => xs.map(x => (a, x)) }
.sorted
diff --git a/dataset/src/test/scala/frameless/InjectionTests.scala b/dataset/src/test/scala/frameless/InjectionTests.scala
index c17a52bd7..9ee252616 100644
--- a/dataset/src/test/scala/frameless/InjectionTests.scala
+++ b/dataset/src/test/scala/frameless/InjectionTests.scala
@@ -10,6 +10,7 @@ case object France extends Country
case object Russia extends Country
object Country {
+
implicit val arbitrary: Arbitrary[Country] =
Arbitrary(Arbitrary.arbitrary[Boolean].map(injection.invert))
@@ -23,15 +24,18 @@ case object Pasta extends Food
case object Rice extends Food
object Food {
+
implicit val arbitrary: Arbitrary[Food] =
- Arbitrary(Arbitrary.arbitrary[Int].map(i => injection.invert(Math.abs(i % 3))))
+ Arbitrary(
+ Arbitrary.arbitrary[Int].map(i => injection.invert(Math.abs(i % 3)))
+ )
implicit val injection: Injection[Food, Int] =
Injection(
{
case Burger => 0
- case Pasta => 1
- case Rice => 2
+ case Pasta => 1
+ case Rice => 2
},
{
case 0 => Burger
@@ -46,10 +50,13 @@ class LocalDateTime {
var instant: Long = _
override def equals(o: Any): Boolean =
- o.isInstanceOf[LocalDateTime] && o.asInstanceOf[LocalDateTime].instant == instant
+ o.isInstanceOf[LocalDateTime] && o
+ .asInstanceOf[LocalDateTime]
+ .instant == instant
}
object LocalDateTime {
+
implicit val arbitrary: Arbitrary[LocalDateTime] =
Arbitrary(Arbitrary.arbitrary[Long].map(injection.invert))
@@ -76,8 +83,12 @@ case class I[A](value: A)
object I {
implicit def injection[A]: Injection[I[A], A] = Injection(_.value, I(_))
- implicit def typedEncoder[A: TypedEncoder]: TypedEncoder[I[A]] = TypedEncoder.usingInjection[I[A], A]
- implicit def arbitrary[A: Arbitrary]: Arbitrary[I[A]] = Arbitrary(Arbitrary.arbitrary[A].map(I(_)))
+
+ implicit def typedEncoder[A: TypedEncoder]: TypedEncoder[I[A]] =
+ TypedEncoder.usingInjection[I[A], A]
+
+ implicit def arbitrary[A: Arbitrary]: Arbitrary[I[A]] =
+ Arbitrary(Arbitrary.arbitrary[A].map(I(_)))
}
sealed trait Employee
@@ -86,6 +97,7 @@ case object PartTime extends Employee
case object FullTime extends Employee
object Employee {
+
implicit val arbitrary: Arbitrary[Employee] =
Arbitrary(Gen.oneOf(Casual, PartTime, FullTime))
}
@@ -95,6 +107,7 @@ case object Nothing extends Maybe
case class Just(get: Int) extends Maybe
sealed trait Switch
+
object Switch {
case object Off extends Switch
case object On extends Switch
@@ -109,6 +122,7 @@ case class Green() extends Pixel
case class Blue() extends Pixel
object Pixel {
+
implicit val arbitrary: Arbitrary[Pixel] =
Arbitrary(Gen.oneOf(Red(), Green(), Blue()))
}
@@ -118,6 +132,7 @@ case object Closed extends Connection[Nothing]
case object Open extends Connection[Nothing]
object Connection {
+
implicit def arbitrary[A]: Arbitrary[Connection[A]] =
Arbitrary(Gen.oneOf(Closed, Open))
}
@@ -127,6 +142,7 @@ case object Car extends Vehicle("red")
case object Bike extends Vehicle("blue")
object Vehicle {
+
implicit val arbitrary: Arbitrary[Vehicle] =
Arbitrary(Gen.oneOf(Car, Bike))
}
@@ -159,7 +175,9 @@ class InjectionTests extends TypedDatasetSuite {
check(forAll(prop[Option[I[X1[Int]]]] _))
assert(TypedEncoder[I[Int]].catalystRepr == TypedEncoder[Int].catalystRepr)
- assert(TypedEncoder[I[I[Int]]].catalystRepr == TypedEncoder[Int].catalystRepr)
+ assert(
+ TypedEncoder[I[I[Int]]].catalystRepr == TypedEncoder[Int].catalystRepr
+ )
assert(TypedEncoder[I[Option[Int]]].nullable)
}
@@ -176,12 +194,18 @@ class InjectionTests extends TypedDatasetSuite {
check(forAll(prop[X2[Person, Person]] _))
check(forAll(prop[Person] _))
- assert(TypedEncoder[Person].catalystRepr == TypedEncoder[(Int, String)].catalystRepr)
+ assert(
+ TypedEncoder[Person].catalystRepr == TypedEncoder[
+ (Int, String)
+ ].catalystRepr
+ )
}
test("Resolve ambiguity by importing usingDerivation") {
import TypedEncoder.usingDerivation
- assert(implicitly[TypedEncoder[Person]].isInstanceOf[RecordEncoder[Person, _, _]])
+ assert(
+ implicitly[TypedEncoder[Person]].isInstanceOf[RecordEncoder[Person, _, _]]
+ )
check(forAll(prop[Person] _))
}
@@ -200,7 +224,9 @@ class InjectionTests extends TypedDatasetSuite {
check(forAll(prop[X2[Employee, Employee]] _))
check(forAll(prop[Employee] _))
- assert(TypedEncoder[Employee].catalystRepr == TypedEncoder[String].catalystRepr)
+ assert(
+ TypedEncoder[Employee].catalystRepr == TypedEncoder[String].catalystRepr
+ )
}
test("TypedEncoder[Maybe] cannot be derived") {
@@ -220,7 +246,9 @@ class InjectionTests extends TypedDatasetSuite {
check(forAll(prop[X2[Switch, Switch]] _))
check(forAll(prop[Switch] _))
- assert(TypedEncoder[Switch].catalystRepr == TypedEncoder[String].catalystRepr)
+ assert(
+ TypedEncoder[Switch].catalystRepr == TypedEncoder[String].catalystRepr
+ )
}
test("Derive encoder for type with data constructors defined as parameterless case classes") {
@@ -231,7 +259,9 @@ class InjectionTests extends TypedDatasetSuite {
check(forAll(prop[X2[Pixel, Pixel]] _))
check(forAll(prop[Pixel] _))
- assert(TypedEncoder[Pixel].catalystRepr == TypedEncoder[String].catalystRepr)
+ assert(
+ TypedEncoder[Pixel].catalystRepr == TypedEncoder[String].catalystRepr
+ )
}
test("Derive encoder for phantom type") {
@@ -242,7 +272,11 @@ class InjectionTests extends TypedDatasetSuite {
check(forAll(prop[X2[Connection[Int], Connection[Int]]] _))
check(forAll(prop[Connection[Int]] _))
- assert(TypedEncoder[Connection[Int]].catalystRepr == TypedEncoder[String].catalystRepr)
+ assert(
+ TypedEncoder[Connection[Int]].catalystRepr == TypedEncoder[
+ String
+ ].catalystRepr
+ )
}
test("Derive encoder for ADT with abstract class as the base type") {
@@ -253,26 +287,36 @@ class InjectionTests extends TypedDatasetSuite {
check(forAll(prop[X2[Vehicle, Vehicle]] _))
check(forAll(prop[Vehicle] _))
- assert(TypedEncoder[Vehicle].catalystRepr == TypedEncoder[String].catalystRepr)
+ assert(
+ TypedEncoder[Vehicle].catalystRepr == TypedEncoder[String].catalystRepr
+ )
}
- test("apply method of derived Injection instance produces the correct string") {
+ test(
+ "apply method of derived Injection instance produces the correct string"
+ ) {
import frameless.TypedEncoder.injections._
assert(implicitly[Injection[Employee, String]].apply(Casual) === "Casual")
assert(implicitly[Injection[Switch, String]].apply(Switch.On) === "On")
assert(implicitly[Injection[Pixel, String]].apply(Blue()) === "Blue")
- assert(implicitly[Injection[Connection[Int], String]].apply(Open) === "Open")
+ assert(
+ implicitly[Injection[Connection[Int], String]].apply(Open) === "Open"
+ )
assert(implicitly[Injection[Vehicle, String]].apply(Bike) === "Bike")
}
- test("invert method of derived Injection instance produces the correct value") {
+ test(
+ "invert method of derived Injection instance produces the correct value"
+ ) {
import frameless.TypedEncoder.injections._
assert(implicitly[Injection[Employee, String]].invert("Casual") === Casual)
assert(implicitly[Injection[Switch, String]].invert("On") === Switch.On)
assert(implicitly[Injection[Pixel, String]].invert("Blue") === Blue())
- assert(implicitly[Injection[Connection[Int], String]].invert("Open") === Open)
+ assert(
+ implicitly[Injection[Connection[Int], String]].invert("Open") === Open
+ )
assert(implicitly[Injection[Vehicle, String]].invert("Bike") === Bike)
}
diff --git a/dataset/src/test/scala/frameless/JobTests.scala b/dataset/src/test/scala/frameless/JobTests.scala
index 9650a020f..8ef20970c 100644
--- a/dataset/src/test/scala/frameless/JobTests.scala
+++ b/dataset/src/test/scala/frameless/JobTests.scala
@@ -6,13 +6,20 @@ import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks
import org.scalatest.freespec.AnyFreeSpec
import org.scalatest.matchers.should.Matchers
-
-class JobTests extends AnyFreeSpec with BeforeAndAfterAll with SparkTesting with ScalaCheckDrivenPropertyChecks with Matchers {
+class JobTests
+ extends AnyFreeSpec
+ with BeforeAndAfterAll
+ with SparkTesting
+ with ScalaCheckDrivenPropertyChecks
+ with Matchers {
"map" - {
"identity" in {
- def check[T](implicit arb: Arbitrary[T]) = forAll {
- t: T => Job(t).map(identity).run() shouldEqual Job(t).run()
+ def check[T](
+ implicit
+ arb: Arbitrary[T]
+ ) = forAll { t: T =>
+ Job(t).map(identity).run() shouldEqual Job(t).run()
}
check[Int]
@@ -21,8 +28,8 @@ class JobTests extends AnyFreeSpec with BeforeAndAfterAll with SparkTesting with
val f1: Int => Int = _ + 1
val f2: Int => Int = (i: Int) => i * i
- "composition" in forAll {
- i: Int => Job(i).map(f1).map(f2).run() shouldEqual Job(i).map(f1 andThen f2).run()
+ "composition" in forAll { i: Int =>
+ Job(i).map(f1).map(f2).run() shouldEqual Job(i).map(f1 andThen f2).run()
}
}
@@ -30,25 +37,26 @@ class JobTests extends AnyFreeSpec with BeforeAndAfterAll with SparkTesting with
val f1: Int => Job[Int] = (i: Int) => Job(i + 1)
val f2: Int => Job[Int] = (i: Int) => Job(i * i)
- "left identity" in forAll {
- i: Int => Job(i).flatMap(f1).run() shouldEqual f1(i).run()
+ "left identity" in forAll { i: Int =>
+ Job(i).flatMap(f1).run() shouldEqual f1(i).run()
}
- "right identity" in forAll {
- i: Int => Job(i).flatMap(i => Job.apply(i)).run() shouldEqual Job(i).run()
+ "right identity" in forAll { i: Int =>
+ Job(i).flatMap(i => Job.apply(i)).run() shouldEqual Job(i).run()
}
- "associativity" in forAll {
- i: Int => Job(i).flatMap(f1).flatMap(f2).run() shouldEqual Job(i).flatMap(ii => f1(ii).flatMap(f2)).run()
+ "associativity" in forAll { i: Int =>
+ Job(i).flatMap(f1).flatMap(f2).run() shouldEqual Job(i)
+ .flatMap(ii => f1(ii).flatMap(f2))
+ .run()
}
}
"properties" - {
- "read back" in forAll {
- (k:String, v: String) =>
- val scopedKey = "frameless.tests." + k
- Job(1).withLocalProperty(scopedKey,v).run()
- sc.getLocalProperty(scopedKey) shouldBe v
+ "read back" in forAll { (k: String, v: String) =>
+ val scopedKey = "frameless.tests." + k
+ Job(1).withLocalProperty(scopedKey, v).run()
+ sc.getLocalProperty(scopedKey) shouldBe v
}
}
-}
\ No newline at end of file
+}
diff --git a/dataset/src/test/scala/frameless/JoinTests.scala b/dataset/src/test/scala/frameless/JoinTests.scala
index b34911c4f..a07ce4b35 100644
--- a/dataset/src/test/scala/frameless/JoinTests.scala
+++ b/dataset/src/test/scala/frameless/JoinTests.scala
@@ -1,20 +1,21 @@
package frameless
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.{ StructField, StructType }
import org.scalacheck.Prop
import org.scalacheck.Prop._
class JoinTests extends TypedDatasetSuite {
test("ab.joinCross(ac)") {
def prop[
- A : TypedEncoder : Ordering,
- B : TypedEncoder : Ordering,
- C : TypedEncoder : Ordering
- ](left: List[X2[A, B]], right: List[X2[A, C]]): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder: Ordering,
+ C: TypedEncoder: Ordering
+ ](left: List[X2[A, B]],
+ right: List[X2[A, C]]
+ ): Prop = {
val leftDs = TypedDataset.create(left)
val rightDs = TypedDataset.create(right)
- val joinedDs = leftDs
- .joinCross(rightDs)
+ val joinedDs = leftDs.joinCross(rightDs)
val joinedData = joinedDs.collect().run().toVector.sorted
@@ -25,9 +26,12 @@ class JoinTests extends TypedDatasetSuite {
} yield (ab, ac)
}.toVector
- val equalSchemas = joinedDs.schema ?= StructType(Seq(
- StructField("_1", leftDs.schema, nullable = false),
- StructField("_2", rightDs.schema, nullable = false)))
+ val equalSchemas = joinedDs.schema ?= StructType(
+ Seq(
+ StructField("_1", leftDs.schema, nullable = false),
+ StructField("_2", rightDs.schema, nullable = false)
+ )
+ )
(joined.sorted ?= joinedData) && equalSchemas
}
@@ -37,19 +41,21 @@ class JoinTests extends TypedDatasetSuite {
test("ab.joinFull(ac)(ab.a == ac.a)") {
def prop[
- A : TypedEncoder : Ordering,
- B : TypedEncoder : Ordering,
- C : TypedEncoder : Ordering
- ](left: List[X2[A, B]], right: List[X2[A, C]]): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder: Ordering,
+ C: TypedEncoder: Ordering
+ ](left: List[X2[A, B]],
+ right: List[X2[A, C]]
+ ): Prop = {
val leftDs = TypedDataset.create(left)
val rightDs = TypedDataset.create(right)
- val joinedDs = leftDs
- .joinFull(rightDs)(leftDs.col('a) === rightDs.col('a))
+ val joinedDs =
+ leftDs.joinFull(rightDs)(leftDs.col('a) === rightDs.col('a))
val joinedData = joinedDs.collect().run().toVector.sorted
val rightKeys = right.map(_.a).toSet
- val leftKeys = left.map(_.a).toSet
+ val leftKeys = left.map(_.a).toSet
val joined = {
for {
ab <- left
@@ -65,9 +71,12 @@ class JoinTests extends TypedDatasetSuite {
} yield (None, Some(ac))
}.toVector
- val equalSchemas = joinedDs.schema ?= StructType(Seq(
- StructField("_1", leftDs.schema, nullable = true),
- StructField("_2", rightDs.schema, nullable = true)))
+ val equalSchemas = joinedDs.schema ?= StructType(
+ Seq(
+ StructField("_1", leftDs.schema, nullable = true),
+ StructField("_2", rightDs.schema, nullable = true)
+ )
+ )
(joined.sorted ?= joinedData) && equalSchemas
}
@@ -77,14 +86,16 @@ class JoinTests extends TypedDatasetSuite {
test("ab.joinInner(ac)(ab.a == ac.a)") {
def prop[
- A : TypedEncoder : Ordering,
- B : TypedEncoder : Ordering,
- C : TypedEncoder : Ordering
- ](left: List[X2[A, B]], right: List[X2[A, C]]): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder: Ordering,
+ C: TypedEncoder: Ordering
+ ](left: List[X2[A, B]],
+ right: List[X2[A, C]]
+ ): Prop = {
val leftDs = TypedDataset.create(left)
val rightDs = TypedDataset.create(right)
- val joinedDs = leftDs
- .joinInner(rightDs)(leftDs.col('a) === rightDs.col('a))
+ val joinedDs =
+ leftDs.joinInner(rightDs)(leftDs.col('a) === rightDs.col('a))
val joinedData = joinedDs.collect().run().toVector.sorted
@@ -95,9 +106,12 @@ class JoinTests extends TypedDatasetSuite {
} yield (ab, ac)
}.toVector
- val equalSchemas = joinedDs.schema ?= StructType(Seq(
- StructField("_1", leftDs.schema, nullable = false),
- StructField("_2", rightDs.schema, nullable = false)))
+ val equalSchemas = joinedDs.schema ?= StructType(
+ Seq(
+ StructField("_1", leftDs.schema, nullable = false),
+ StructField("_2", rightDs.schema, nullable = false)
+ )
+ )
(joined.sorted ?= joinedData) && equalSchemas
}
@@ -107,14 +121,16 @@ class JoinTests extends TypedDatasetSuite {
test("ab.joinLeft(ac)(ab.a == ac.a)") {
def prop[
- A : TypedEncoder : Ordering,
- B : TypedEncoder : Ordering,
- C : TypedEncoder : Ordering
- ](left: List[X2[A, B]], right: List[X2[A, C]]): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder: Ordering,
+ C: TypedEncoder: Ordering
+ ](left: List[X2[A, B]],
+ right: List[X2[A, C]]
+ ): Prop = {
val leftDs = TypedDataset.create(left)
val rightDs = TypedDataset.create(right)
- val joinedDs = leftDs
- .joinLeft(rightDs)(leftDs.col('a) === rightDs.col('a))
+ val joinedDs =
+ leftDs.joinLeft(rightDs)(leftDs.col('a) === rightDs.col('a))
val joinedData = joinedDs.collect().run().toVector.sorted
@@ -130,11 +146,16 @@ class JoinTests extends TypedDatasetSuite {
} yield (ab, None)
}.toVector
- val equalSchemas = joinedDs.schema ?= StructType(Seq(
- StructField("_1", leftDs.schema, nullable = false),
- StructField("_2", rightDs.schema, nullable = true)))
+ val equalSchemas = joinedDs.schema ?= StructType(
+ Seq(
+ StructField("_1", leftDs.schema, nullable = false),
+ StructField("_2", rightDs.schema, nullable = true)
+ )
+ )
- (joined.sorted ?= joinedData) && (joinedData.map(_._1).toSet ?= left.toSet) && equalSchemas
+ (joined.sorted ?= joinedData) && (joinedData
+ .map(_._1)
+ .toSet ?= left.toSet) && equalSchemas
}
check(forAll(prop[Int, Long, String] _))
@@ -142,15 +163,17 @@ class JoinTests extends TypedDatasetSuite {
test("ab.joinLeftAnti(ac)(ab.a == ac.a)") {
def prop[
- A : TypedEncoder : Ordering,
- B : TypedEncoder : Ordering,
- C : TypedEncoder : Ordering
- ](left: List[X2[A, B]], right: List[X2[A, C]]): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder: Ordering,
+ C: TypedEncoder: Ordering
+ ](left: List[X2[A, B]],
+ right: List[X2[A, C]]
+ ): Prop = {
val leftDs = TypedDataset.create(left)
val rightDs = TypedDataset.create(right)
val rightKeys = right.map(_.a).toSet
- val joinedDs = leftDs
- .joinLeftAnti(rightDs)(leftDs.col('a) === rightDs.col('a))
+ val joinedDs =
+ leftDs.joinLeftAnti(rightDs)(leftDs.col('a) === rightDs.col('a))
val joinedData = joinedDs.collect().run().toVector.sorted
@@ -170,15 +193,17 @@ class JoinTests extends TypedDatasetSuite {
test("ab.joinLeftSemi(ac)(ab.a == ac.a)") {
def prop[
- A : TypedEncoder : Ordering,
- B : TypedEncoder : Ordering,
- C : TypedEncoder : Ordering
- ](left: List[X2[A, B]], right: List[X2[A, C]]): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder: Ordering,
+ C: TypedEncoder: Ordering
+ ](left: List[X2[A, B]],
+ right: List[X2[A, C]]
+ ): Prop = {
val leftDs = TypedDataset.create(left)
val rightDs = TypedDataset.create(right)
val rightKeys = right.map(_.a).toSet
- val joinedDs = leftDs
- .joinLeftSemi(rightDs)(leftDs.col('a) === rightDs.col('a))
+ val joinedDs =
+ leftDs.joinLeftSemi(rightDs)(leftDs.col('a) === rightDs.col('a))
val joinedData = joinedDs.collect().run().toVector.sorted
@@ -198,14 +223,16 @@ class JoinTests extends TypedDatasetSuite {
test("ab.joinRight(ac)(ab.a == ac.a)") {
def prop[
- A : TypedEncoder : Ordering,
- B : TypedEncoder : Ordering,
- C : TypedEncoder : Ordering
- ](left: List[X2[A, B]], right: List[X2[A, C]]): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder: Ordering,
+ C: TypedEncoder: Ordering
+ ](left: List[X2[A, B]],
+ right: List[X2[A, C]]
+ ): Prop = {
val leftDs = TypedDataset.create(left)
val rightDs = TypedDataset.create(right)
- val joinedDs = leftDs
- .joinRight(rightDs)(leftDs.col('a) === rightDs.col('a))
+ val joinedDs =
+ leftDs.joinRight(rightDs)(leftDs.col('a) === rightDs.col('a))
val joinedData = joinedDs.collect().run().toVector.sorted
@@ -221,11 +248,16 @@ class JoinTests extends TypedDatasetSuite {
} yield (None, ac)
}.toVector
- val equalSchemas = joinedDs.schema ?= StructType(Seq(
- StructField("_1", leftDs.schema, nullable = true),
- StructField("_2", rightDs.schema, nullable = false)))
+ val equalSchemas = joinedDs.schema ?= StructType(
+ Seq(
+ StructField("_1", leftDs.schema, nullable = true),
+ StructField("_2", rightDs.schema, nullable = false)
+ )
+ )
- (joined.sorted ?= joinedData) && (joinedData.map(_._2).toSet ?= right.toSet) && equalSchemas
+ (joined.sorted ?= joinedData) && (joinedData
+ .map(_._2)
+ .toSet ?= right.toSet) && equalSchemas
}
check(forAll(prop[Int, Long, String] _))
diff --git a/dataset/src/test/scala/frameless/LitTests.scala b/dataset/src/test/scala/frameless/LitTests.scala
index 50df45220..dd1d282d3 100644
--- a/dataset/src/test/scala/frameless/LitTests.scala
+++ b/dataset/src/test/scala/frameless/LitTests.scala
@@ -9,22 +9,26 @@ import org.scalacheck.Prop, Prop._
import RecordEncoderTests.Name
class LitTests extends TypedDatasetSuite with Matchers {
- def prop[A: TypedEncoder](value: A)(implicit i0: shapeless.Refute[IsValueClass[A]]): Prop = {
+
+ def prop[A: TypedEncoder](
+ value: A
+ )(implicit
+ i0: shapeless.Refute[IsValueClass[A]]
+ ): Prop = {
val df: TypedDataset[Int] = TypedDataset.create(1 :: Nil)
val l: TypedColumn[Int, A] = lit(value)
// filter forces whole codegen
- val elems = df.deserialized.filter((_:Int) => true).select(l)
+ val elems = df.deserialized
+ .filter((_: Int) => true)
+ .select(l)
.collect()
.run()
.toVector
// otherwise it uses local relation
- val localElems = df.select(l)
- .collect()
- .run()
- .toVector
+ val localElems = df.select(l).collect().run().toVector
val expected = Vector(value)
@@ -56,23 +60,24 @@ class LitTests extends TypedDatasetSuite with Matchers {
}
test("support value class") {
- val initial = Seq(
- Q(name = new Name("Foo"), id = 1),
- Q(name = new Name("Bar"), id = 2))
+ val initial =
+ Seq(Q(name = new Name("Foo"), id = 1), Q(name = new Name("Bar"), id = 2))
val ds = TypedDataset.create(initial)
ds.collect.run() shouldBe initial
val lorem = new Name("Lorem")
- ds.withColumnReplaced('name, functions.litValue(lorem)).
- collect.run() shouldBe initial.map(_.copy(name = lorem))
+ ds.withColumnReplaced('name, functions.litValue(lorem))
+ .collect
+ .run() shouldBe initial.map(_.copy(name = lorem))
}
test("support optional value class") {
val initial = Seq(
R(name = "Foo", id = 1, alias = None),
- R(name = "Bar", id = 2, alias = Some(new Name("Lorem"))))
+ R(name = "Bar", id = 2, alias = Some(new Name("Lorem")))
+ )
val ds = TypedDataset.create(initial)
ds.collect.run() shouldBe initial
@@ -82,13 +87,13 @@ class LitTests extends TypedDatasetSuite with Matchers {
val lit = functions.litValue(someIpsum)
val tds = ds.withColumnReplaced('alias, functions.litValue(someIpsum))
- tds.queryExecution.toString() should include (lit.toString)
+ tds.queryExecution.toString() should include(lit.toString)
- tds.
- collect.run() shouldBe initial.map(_.copy(alias = someIpsum))
+ tds.collect.run() shouldBe initial.map(_.copy(alias = someIpsum))
- ds.withColumnReplaced('alias, functions.litValue(Option.empty[Name])).
- collect.run() shouldBe initial.map(_.copy(alias = None))
+ ds.withColumnReplaced('alias, functions.litValue(Option.empty[Name]))
+ .collect
+ .run() shouldBe initial.map(_.copy(alias = None))
}
test("#205: comparing literals encoded using Injection") {
diff --git a/dataset/src/test/scala/frameless/NumericTests.scala b/dataset/src/test/scala/frameless/NumericTests.scala
index 0c13ae5a3..7faf8d7b4 100644
--- a/dataset/src/test/scala/frameless/NumericTests.scala
+++ b/dataset/src/test/scala/frameless/NumericTests.scala
@@ -1,7 +1,7 @@
package frameless
import org.apache.spark.sql.Encoder
-import org.scalacheck.{Arbitrary, Gen, Prop}
+import org.scalacheck.{ Arbitrary, Gen, Prop }
import org.scalacheck.Prop._
import org.scalatest.matchers.should.Matchers
@@ -43,7 +43,10 @@ class NumericTests extends TypedDatasetSuite with Matchers {
}
test("multiply") {
- def prop[A: TypedEncoder : CatalystNumeric : Numeric : ClassTag](a: A, b: A): Prop = {
+ def prop[A: TypedEncoder: CatalystNumeric: Numeric: ClassTag](
+ a: A,
+ b: A
+ ): Prop = {
val df = TypedDataset.create(X2(a, b) :: Nil)
val result = implicitly[Numeric[A]].times(a, b)
val got = df.select(df.col('a) * df.col('b)).collect().run()
@@ -59,27 +62,36 @@ class NumericTests extends TypedDatasetSuite with Matchers {
}
test("divide") {
- def prop[A: TypedEncoder: CatalystNumeric: Numeric](a: A, b: A)(implicit cd: CatalystDivisible[A, Double]): Prop = {
+ def prop[A: TypedEncoder: CatalystNumeric: Numeric](
+ a: A,
+ b: A
+ )(implicit
+ cd: CatalystDivisible[A, Double]
+ ): Prop = {
val df = TypedDataset.create(X2(a, b) :: Nil)
- if (b == 0) proved else {
- val div: Double = implicitly[Numeric[A]].toDouble(a) / implicitly[Numeric[A]].toDouble(b)
- val got: Seq[Double] = df.select(df.col('a) / df.col('b)).collect().run()
+ if (b == 0) proved
+ else {
+ val div: Double = implicitly[Numeric[A]]
+ .toDouble(a) / implicitly[Numeric[A]].toDouble(b)
+ val got: Seq[Double] =
+ df.select(df.col('a) / df.col('b)).collect().run()
got ?= (div :: Nil)
}
}
- check(prop[Byte ] _)
+ check(prop[Byte] _)
check(prop[Double] _)
- check(prop[Int ] _)
- check(prop[Long ] _)
- check(prop[Short ] _)
+ check(prop[Int] _)
+ check(prop[Long] _)
+ check(prop[Short] _)
}
test("divide BigDecimals") {
def prop(a: BigDecimal, b: BigDecimal): Prop = {
val df = TypedDataset.create(X2(a, b) :: Nil)
- if (b.doubleValue == 0) proved else {
+ if (b.doubleValue == 0) proved
+ else {
// Spark performs something in between Double division and BigDecimal division,
// we approximate it using double vision and `approximatelyEqual`:
val div = BigDecimal(a.doubleValue / b.doubleValue)
@@ -107,24 +119,31 @@ class NumericTests extends TypedDatasetSuite with Matchers {
}
object NumericMod {
+
implicit val byteInstance = new NumericMod[Byte] {
def mod(a: Byte, b: Byte) = (a % b).toByte
}
+
implicit val doubleInstance = new NumericMod[Double] {
def mod(a: Double, b: Double) = a % b
}
+
implicit val floatInstance = new NumericMod[Float] {
def mod(a: Float, b: Float) = a % b
}
+
implicit val intInstance = new NumericMod[Int] {
def mod(a: Int, b: Int) = a % b
}
+
implicit val longInstance = new NumericMod[Long] {
def mod(a: Long, b: Long) = a % b
}
+
implicit val shortInstance = new NumericMod[Short] {
def mod(a: Short, b: Short) = (a % b).toShort
}
+
implicit val bigDecimalInstance = new NumericMod[BigDecimal] {
def mod(a: BigDecimal, b: BigDecimal) = a % b
}
@@ -133,9 +152,10 @@ class NumericTests extends TypedDatasetSuite with Matchers {
test("mod") {
import NumericMod._
- def prop[A: TypedEncoder : CatalystNumeric : NumericMod](a: A, b: A): Prop = {
+ def prop[A: TypedEncoder: CatalystNumeric: NumericMod](a: A, b: A): Prop = {
val df = TypedDataset.create(X2(a, b) :: Nil)
- if (b == 0) proved else {
+ if (b == 0) proved
+ else {
val mod: A = implicitly[NumericMod[A]].mod(a, b)
val got: Seq[A] = df.select(df.col('a) % df.col('b)).collect().run()
@@ -145,19 +165,23 @@ class NumericTests extends TypedDatasetSuite with Matchers {
check(prop[Byte] _)
check(prop[Double] _)
- check(prop[Int ] _)
- check(prop[Long ] _)
- check(prop[Short ] _)
+ check(prop[Int] _)
+ check(prop[Long] _)
+ check(prop[Short] _)
check(prop[BigDecimal] _)
}
- test("a mod lit(b)"){
+ test("a mod lit(b)") {
import NumericMod._
- def prop[A: TypedEncoder : CatalystNumeric : NumericMod](elem: A, data: X1[A]): Prop = {
+ def prop[A: TypedEncoder: CatalystNumeric: NumericMod](
+ elem: A,
+ data: X1[A]
+ ): Prop = {
val dataset = TypedDataset.create(Seq(data))
val a = dataset.col('a)
- if (elem == 0) proved else {
+ if (elem == 0) proved
+ else {
val mod: A = implicitly[NumericMod[A]].mod(data.a, elem)
val got: Seq[A] = dataset.select(a % elem).collect().run()
@@ -167,9 +191,9 @@ class NumericTests extends TypedDatasetSuite with Matchers {
check(prop[Byte] _)
check(prop[Double] _)
- check(prop[Int ] _)
- check(prop[Long ] _)
- check(prop[Short ] _)
+ check(prop[Int] _)
+ check(prop[Long] _)
+ check(prop[Short] _)
check(prop[BigDecimal] _)
}
@@ -180,12 +204,13 @@ class NumericTests extends TypedDatasetSuite with Matchers {
implicit val doubleWithNaN = Arbitrary {
implicitly[Arbitrary[Double]].arbitrary.flatMap(Gen.oneOf(_, Double.NaN))
}
- implicit val x1 = Arbitrary{ doubleWithNaN.arbitrary.map(X1(_)) }
+ implicit val x1 = Arbitrary { doubleWithNaN.arbitrary.map(X1(_)) }
- def prop[A : TypedEncoder : Encoder : CatalystNaN](data: List[X1[A]]): Prop = {
+ def prop[A: TypedEncoder: Encoder: CatalystNaN](data: List[X1[A]]): Prop = {
val ds = TypedDataset.create(data)
- val expected = ds.toDF().filter(!$"a".isNaN).map(_.getAs[A](0)).collect().toSeq
+ val expected =
+ ds.toDF().filter(!$"a".isNaN).map(_.getAs[A](0)).collect().toSeq
val rs = ds.filter(!ds('a).isNaN).collect().run().map(_.a)
rs ?= expected
diff --git a/dataset/src/test/scala/frameless/OrderByTests.scala b/dataset/src/test/scala/frameless/OrderByTests.scala
index 98bd7442d..2b07fa860 100644
--- a/dataset/src/test/scala/frameless/OrderByTests.scala
+++ b/dataset/src/test/scala/frameless/OrderByTests.scala
@@ -7,19 +7,28 @@ import org.apache.spark.sql.Column
import org.scalatest.matchers.should.Matchers
class OrderByTests extends TypedDatasetSuite with Matchers {
- def sortings[A : CatalystOrdered, T]: Seq[(TypedColumn[T, A] => SortedTypedColumn[T, A], Column => Column)] = Seq(
- (_.desc, _.desc),
- (_.asc, _.asc),
- (t => t, t => t) //default ascending
- )
+
+ def sortings[
+ A: CatalystOrdered,
+ T
+ ]: Seq[(TypedColumn[T, A] => SortedTypedColumn[T, A], Column => Column)] =
+ Seq(
+ (_.desc, _.desc),
+ (_.asc, _.asc),
+ (t => t, t => t) // default ascending
+ )
test("single column non nullable orderBy") {
- def prop[A: TypedEncoder : CatalystOrdered](data: Vector[X1[A]]): Prop = {
+ def prop[A: TypedEncoder: CatalystOrdered](data: Vector[X1[A]]): Prop = {
val ds = TypedDataset.create(data)
- sortings[A, X1[A]].map { case (typ, untyp) =>
- ds.dataset.orderBy(untyp(ds.dataset.col("a"))).collect().toVector.?=(
- ds.orderBy(typ(ds('a))).collect().run().toVector)
+ sortings[A, X1[A]].map {
+ case (typ, untyp) =>
+ ds.dataset
+ .orderBy(untyp(ds.dataset.col("a")))
+ .collect()
+ .toVector
+ .?=(ds.orderBy(typ(ds('a))).collect().run().toVector)
}.reduce(_ && _)
}
@@ -36,12 +45,16 @@ class OrderByTests extends TypedDatasetSuite with Matchers {
}
test("single column non nullable partition sorting") {
- def prop[A: TypedEncoder : CatalystOrdered](data: Vector[X1[A]]): Prop = {
+ def prop[A: TypedEncoder: CatalystOrdered](data: Vector[X1[A]]): Prop = {
val ds = TypedDataset.create(data)
- sortings[A, X1[A]].map { case (typ, untyp) =>
- ds.dataset.sortWithinPartitions(untyp(ds.dataset.col("a"))).collect().toVector.?=(
- ds.sortWithinPartitions(typ(ds('a))).collect().run().toVector)
+ sortings[A, X1[A]].map {
+ case (typ, untyp) =>
+ ds.dataset
+ .sortWithinPartitions(untyp(ds.dataset.col("a")))
+ .collect()
+ .toVector
+ .?=(ds.sortWithinPartitions(typ(ds('a))).collect().run().toVector)
}.reduce(_ && _)
}
@@ -58,15 +71,34 @@ class OrderByTests extends TypedDatasetSuite with Matchers {
}
test("two columns non nullable orderBy") {
- def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X2[A,B]]): Prop = {
+ def prop[
+ A: TypedEncoder: CatalystOrdered,
+ B: TypedEncoder: CatalystOrdered
+ ](data: Vector[X2[A, B]]
+ ): Prop = {
val ds = TypedDataset.create(data)
- sortings[A, X2[A, B]].reverse.zip(sortings[B, X2[A, B]]).map { case ((typA, untypA), (typB, untypB)) =>
- val vanillaSpark = ds.dataset.orderBy(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b"))).collect().toVector
- vanillaSpark.?=(ds.orderBy(typA(ds('a)), typB(ds('b))).collect().run().toVector).&&(
- vanillaSpark ?= ds.orderByMany(typA(ds('a)), typB(ds('b))).collect().run().toVector
- )
- }.reduce(_ && _)
+ sortings[A, X2[A, B]].reverse
+ .zip(sortings[B, X2[A, B]])
+ .map {
+ case ((typA, untypA), (typB, untypB)) =>
+ val vanillaSpark = ds.dataset
+ .orderBy(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b")))
+ .collect()
+ .toVector
+ vanillaSpark
+ .?=(
+ ds.orderBy(typA(ds('a)), typB(ds('b))).collect().run().toVector
+ )
+ .&&(
+ vanillaSpark ?= ds
+ .orderByMany(typA(ds('a)), typB(ds('b)))
+ .collect()
+ .run()
+ .toVector
+ )
+ }
+ .reduce(_ && _)
}
check(forAll(prop[SQLDate, Long] _))
@@ -75,15 +107,40 @@ class OrderByTests extends TypedDatasetSuite with Matchers {
}
test("two columns non nullable partition sorting") {
- def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X2[A,B]]): Prop = {
+ def prop[
+ A: TypedEncoder: CatalystOrdered,
+ B: TypedEncoder: CatalystOrdered
+ ](data: Vector[X2[A, B]]
+ ): Prop = {
val ds = TypedDataset.create(data)
- sortings[A, X2[A, B]].reverse.zip(sortings[B, X2[A, B]]).map { case ((typA, untypA), (typB, untypB)) =>
- val vanillaSpark = ds.dataset.sortWithinPartitions(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b"))).collect().toVector
- vanillaSpark.?=(ds.sortWithinPartitions(typA(ds('a)), typB(ds('b))).collect().run().toVector).&&(
- vanillaSpark ?= ds.sortWithinPartitionsMany(typA(ds('a)), typB(ds('b))).collect().run().toVector
- )
- }.reduce(_ && _)
+ sortings[A, X2[A, B]].reverse
+ .zip(sortings[B, X2[A, B]])
+ .map {
+ case ((typA, untypA), (typB, untypB)) =>
+ val vanillaSpark = ds.dataset
+ .sortWithinPartitions(
+ untypA(ds.dataset.col("a")),
+ untypB(ds.dataset.col("b"))
+ )
+ .collect()
+ .toVector
+ vanillaSpark
+ .?=(
+ ds.sortWithinPartitions(typA(ds('a)), typB(ds('b)))
+ .collect()
+ .run()
+ .toVector
+ )
+ .&&(
+ vanillaSpark ?= ds
+ .sortWithinPartitionsMany(typA(ds('a)), typB(ds('b)))
+ .collect()
+ .run()
+ .toVector
+ )
+ }
+ .reduce(_ && _)
}
check(forAll(prop[SQLDate, Long] _))
@@ -92,21 +149,43 @@ class OrderByTests extends TypedDatasetSuite with Matchers {
}
test("three columns non nullable orderBy") {
- def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X3[A,B,A]]): Prop = {
+ def prop[
+ A: TypedEncoder: CatalystOrdered,
+ B: TypedEncoder: CatalystOrdered
+ ](data: Vector[X3[A, B, A]]
+ ): Prop = {
val ds = TypedDataset.create(data)
sortings[A, X3[A, B, A]].reverse
.zip(sortings[B, X3[A, B, A]])
.zip(sortings[A, X3[A, B, A]])
- .map { case (((typA, untypA), (typB, untypB)), (typA2, untypA2)) =>
- val vanillaSpark = ds.dataset
- .orderBy(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b")), untypA2(ds.dataset.col("c")))
- .collect().toVector
-
- vanillaSpark.?=(ds.orderBy(typA(ds('a)), typB(ds('b)), typA2(ds('c))).collect().run().toVector).&&(
- vanillaSpark ?= ds.orderByMany(typA(ds('a)), typB(ds('b)), typA2(ds('c))).collect().run().toVector
- )
- }.reduce(_ && _)
+ .map {
+ case (((typA, untypA), (typB, untypB)), (typA2, untypA2)) =>
+ val vanillaSpark = ds.dataset
+ .orderBy(
+ untypA(ds.dataset.col("a")),
+ untypB(ds.dataset.col("b")),
+ untypA2(ds.dataset.col("c"))
+ )
+ .collect()
+ .toVector
+
+ vanillaSpark
+ .?=(
+ ds.orderBy(typA(ds('a)), typB(ds('b)), typA2(ds('c)))
+ .collect()
+ .run()
+ .toVector
+ )
+ .&&(
+ vanillaSpark ?= ds
+ .orderByMany(typA(ds('a)), typB(ds('b)), typA2(ds('c)))
+ .collect()
+ .run()
+ .toVector
+ )
+ }
+ .reduce(_ && _)
}
check(forAll(prop[SQLDate, Long] _))
@@ -115,21 +194,50 @@ class OrderByTests extends TypedDatasetSuite with Matchers {
}
test("three columns non nullable partition sorting") {
- def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X3[A,B,A]]): Prop = {
+ def prop[
+ A: TypedEncoder: CatalystOrdered,
+ B: TypedEncoder: CatalystOrdered
+ ](data: Vector[X3[A, B, A]]
+ ): Prop = {
val ds = TypedDataset.create(data)
sortings[A, X3[A, B, A]].reverse
.zip(sortings[B, X3[A, B, A]])
.zip(sortings[A, X3[A, B, A]])
- .map { case (((typA, untypA), (typB, untypB)), (typA2, untypA2)) =>
- val vanillaSpark = ds.dataset
- .sortWithinPartitions(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b")), untypA2(ds.dataset.col("c")))
- .collect().toVector
-
- vanillaSpark.?=(ds.sortWithinPartitions(typA(ds('a)), typB(ds('b)), typA2(ds('c))).collect().run().toVector).&&(
- vanillaSpark ?= ds.sortWithinPartitionsMany(typA(ds('a)), typB(ds('b)), typA2(ds('c))).collect().run().toVector
- )
- }.reduce(_ && _)
+ .map {
+ case (((typA, untypA), (typB, untypB)), (typA2, untypA2)) =>
+ val vanillaSpark = ds.dataset
+ .sortWithinPartitions(
+ untypA(ds.dataset.col("a")),
+ untypB(ds.dataset.col("b")),
+ untypA2(ds.dataset.col("c"))
+ )
+ .collect()
+ .toVector
+
+ vanillaSpark
+ .?=(
+ ds.sortWithinPartitions(
+ typA(ds('a)),
+ typB(ds('b)),
+ typA2(ds('c))
+ ).collect()
+ .run()
+ .toVector
+ )
+ .&&(
+ vanillaSpark ?= ds
+ .sortWithinPartitionsMany(
+ typA(ds('a)),
+ typB(ds('b)),
+ typA2(ds('c))
+ )
+ .collect()
+ .run()
+ .toVector
+ )
+ }
+ .reduce(_ && _)
}
check(forAll(prop[SQLDate, Long] _))
@@ -138,13 +246,28 @@ class OrderByTests extends TypedDatasetSuite with Matchers {
}
test("sort support for mixed default and explicit ordering") {
- def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X2[A, B]]): Prop = {
+ def prop[
+ A: TypedEncoder: CatalystOrdered,
+ B: TypedEncoder: CatalystOrdered
+ ](data: Vector[X2[A, B]]
+ ): Prop = {
val ds = TypedDataset.create(data)
- ds.dataset.orderBy(ds.dataset.col("a"), ds.dataset.col("b").desc).collect().toVector.?=(
- ds.orderByMany(ds('a), ds('b).desc).collect().run().toVector) &&
- ds.dataset.sortWithinPartitions(ds.dataset.col("a"), ds.dataset.col("b").desc).collect().toVector.?=(
- ds.sortWithinPartitionsMany(ds('a), ds('b).desc).collect().run().toVector)
+ ds.dataset
+ .orderBy(ds.dataset.col("a"), ds.dataset.col("b").desc)
+ .collect()
+ .toVector
+ .?=(ds.orderByMany(ds('a), ds('b).desc).collect().run().toVector) &&
+ ds.dataset
+ .sortWithinPartitions(ds.dataset.col("a"), ds.dataset.col("b").desc)
+ .collect()
+ .toVector
+ .?=(
+ ds.sortWithinPartitionsMany(ds('a), ds('b).desc)
+ .collect()
+ .run()
+ .toVector
+ )
}
check(forAll(prop[SQLDate, Long] _))
@@ -159,50 +282,67 @@ class OrderByTests extends TypedDatasetSuite with Matchers {
illTyped("""d.sortWithinPartitions(d('b).desc)""")
}
- test("derives a CatalystOrdered for case classes when all fields are comparable") {
+ test(
+ "derives a CatalystOrdered for case classes when all fields are comparable"
+ ) {
type T[A, B] = X3[Int, Boolean, X2[A, B]]
def prop[
- A: TypedEncoder : CatalystOrdered,
- B: TypedEncoder : CatalystOrdered
- ](data: Vector[T[A, B]]): Prop = {
+ A: TypedEncoder: CatalystOrdered,
+ B: TypedEncoder: CatalystOrdered
+ ](data: Vector[T[A, B]]
+ ): Prop = {
val ds = TypedDataset.create(data)
- sortings[X2[A, B], T[A, B]].map { case (typX2, untypX2) =>
- val vanilla = ds.dataset.orderBy(untypX2(ds.dataset.col("c"))).collect().toVector
- val frameless = ds.orderBy(typX2(ds('c))).collect().run.toVector
- vanilla ?= frameless
+ sortings[X2[A, B], T[A, B]].map {
+ case (typX2, untypX2) =>
+ val vanilla =
+ ds.dataset.orderBy(untypX2(ds.dataset.col("c"))).collect().toVector
+ val frameless = ds.orderBy(typX2(ds('c))).collect().run.toVector
+ vanilla ?= frameless
}.reduce(_ && _)
}
check(forAll(prop[Int, Long] _))
check(forAll(prop[(String, SQLDate), Float] _))
// Check that nested case classes are properly derived too
- check(forAll(prop[X2[Boolean, Float], X4[SQLTimestamp, Double, Short, Byte]] _))
+ check(
+ forAll(prop[X2[Boolean, Float], X4[SQLTimestamp, Double, Short, Byte]] _)
+ )
}
test("derives a CatalystOrdered for tuples when all fields are comparable") {
type T[A, B] = X2[Int, (A, B)]
def prop[
- A: TypedEncoder : CatalystOrdered,
- B: TypedEncoder : CatalystOrdered
- ](data: Vector[T[A, B]]): Prop = {
+ A: TypedEncoder: CatalystOrdered,
+ B: TypedEncoder: CatalystOrdered
+ ](data: Vector[T[A, B]]
+ ): Prop = {
val ds = TypedDataset.create(data)
- sortings[(A, B), T[A, B]].map { case (typX2, untypX2) =>
- val vanilla = ds.dataset.orderBy(untypX2(ds.dataset.col("b"))).collect().toVector
- val frameless = ds.orderBy(typX2(ds('b))).collect().run.toVector
- vanilla ?= frameless
+ sortings[(A, B), T[A, B]].map {
+ case (typX2, untypX2) =>
+ val vanilla =
+ ds.dataset.orderBy(untypX2(ds.dataset.col("b"))).collect().toVector
+ val frameless = ds.orderBy(typX2(ds('b))).collect().run.toVector
+ vanilla ?= frameless
}.reduce(_ && _)
}
check(forAll(prop[Int, Long] _))
check(forAll(prop[(String, SQLDate), Float] _))
- check(forAll(prop[X2[Boolean, Float], X1[(SQLTimestamp, Double, Short, Byte)]] _))
+ check(
+ forAll(
+ prop[X2[Boolean, Float], X1[(SQLTimestamp, Double, Short, Byte)]] _
+ )
+ )
}
test("fails to compile when one of the field isn't comparable") {
type T = X2[Int, X2[Int, Map[String, String]]]
val d = TypedDataset.create(X2(1, X2(2, Map("not" -> "comparable"))) :: Nil)
- illTyped("d.orderBy(d('b).desc)", """Cannot compare columns of type frameless.X2\[Int,scala.collection.immutable.Map\[String,String]].""")
+ illTyped(
+ "d.orderBy(d('b).desc)",
+ """Cannot compare columns of type frameless.X2\[Int,scala.collection.immutable.Map\[String,String]]."""
+ )
}
}
diff --git a/dataset/src/test/scala/frameless/RecordEncoderTests.scala b/dataset/src/test/scala/frameless/RecordEncoderTests.scala
index 98274cf01..121785371 100644
--- a/dataset/src/test/scala/frameless/RecordEncoderTests.scala
+++ b/dataset/src/test/scala/frameless/RecordEncoderTests.scala
@@ -1,6 +1,6 @@
package frameless
-import org.apache.spark.sql.{Row, functions => F}
+import org.apache.spark.sql.{ Row, functions => F }
import org.apache.spark.sql.types.{
ArrayType,
BinaryType,
@@ -14,7 +14,7 @@ import org.apache.spark.sql.types.{
StructType
}
-import shapeless.{HList, LabelledGeneric}
+import shapeless.{ HList, LabelledGeneric }
import shapeless.test.illTyped
import org.scalatest.matchers.should.Matchers
@@ -25,19 +25,31 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
}
test("Dropping fields") {
- def dropUnitValues[L <: HList](l: L)(implicit d: DropUnitValues[L]): d.Out = d(l)
- val fields = LabelledGeneric[TupleWithUnits].to(TupleWithUnits(42, "something"))
- dropUnitValues(fields) shouldEqual LabelledGeneric[(Int, String)].to((42, "something"))
+ def dropUnitValues[L <: HList](
+ l: L
+ )(implicit
+ d: DropUnitValues[L]
+ ): d.Out = d(l)
+ val fields =
+ LabelledGeneric[TupleWithUnits].to(TupleWithUnits(42, "something"))
+ dropUnitValues(fields) shouldEqual LabelledGeneric[(Int, String)]
+ .to((42, "something"))
}
test("Representation skips units") {
- assert(TypedEncoder[(Int, String)].catalystRepr == TypedEncoder[TupleWithUnits].catalystRepr)
+ assert(
+ TypedEncoder[(Int, String)].catalystRepr == TypedEncoder[
+ TupleWithUnits
+ ].catalystRepr
+ )
}
test("Serialization skips units") {
val df = session.createDataFrame(Seq((1, "one"), (2, "two")))
val ds = df.as[TupleWithUnits](TypedExpressionEncoder[TupleWithUnits])
- val tds = TypedDataset.create(Seq(TupleWithUnits(1, "one"), TupleWithUnits(2, "two")))
+ val tds = TypedDataset.create(
+ Seq(TupleWithUnits(1, "one"), TupleWithUnits(2, "two"))
+ )
df.collect shouldEqual tds.toDF.collect
ds.collect.toSeq shouldEqual tds.collect.run
@@ -51,7 +63,8 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
test("Empty nested record value becomes none on deserialization") {
val rdd = sc.parallelize(Seq(Row(null)))
- val schema = TypedEncoder[OptionalNesting].catalystRepr.asInstanceOf[StructType]
+ val schema =
+ TypedEncoder[OptionalNesting].catalystRepr.asInstanceOf[StructType]
val df = session.createDataFrame(rdd, schema)
val ds = TypedDataset.createUnsafe(df)(TypedEncoder[OptionalNesting])
@@ -61,7 +74,8 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
test("Deeply nested optional values have correct deserialization") {
val rdd = sc.parallelize(Seq(Row(true, Row(null, null))))
type NestedOptionPair = X2[Boolean, Option[X2[Option[Int], Option[String]]]]
- val schema = TypedEncoder[NestedOptionPair].catalystRepr.asInstanceOf[StructType]
+ val schema =
+ TypedEncoder[NestedOptionPair].catalystRepr.asInstanceOf[StructType]
val df = session.createDataFrame(rdd, schema)
val ds = TypedDataset.createUnsafe(df)(TypedEncoder[NestedOptionPair])
ds.firstOption.run.get shouldBe X2(true, Some(X2(None, None)))
@@ -95,14 +109,16 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
encoder.jvmRepr shouldBe ObjectType(classOf[Name])
encoder.catalystRepr shouldBe StructType(
- Seq(StructField("value", StringType, false)))
+ Seq(StructField("value", StringType, false))
+ )
val sqlContext = session.sqlContext
import sqlContext.implicits._
TypedDataset
.createUnsafe[Name](Seq("Foo", "Bar").toDF)(encoder)
- .collect().run() shouldBe Seq(new Name("Foo"), new Name("Bar"))
+ .collect()
+ .run() shouldBe Seq(new Name("Foo"), new Name("Bar"))
}
@@ -111,7 +127,8 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
illTyped(
// As `Person` is not a Value class
- "val _: RecordFieldEncoder[Person] = RecordFieldEncoder.valueClass")
+ "val _: RecordFieldEncoder[Person] = RecordFieldEncoder.valueClass"
+ )
val fieldEncoder: RecordFieldEncoder[Name] = RecordFieldEncoder.valueClass
@@ -123,24 +140,28 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
encoder.jvmRepr shouldBe ObjectType(classOf[Person])
- val expectedPersonStructType = StructType(Seq(
- StructField("name", StringType, false),
- StructField("age", IntegerType, false)))
+ val expectedPersonStructType = StructType(
+ Seq(
+ StructField("name", StringType, false),
+ StructField("age", IntegerType, false)
+ )
+ )
encoder.catalystRepr shouldBe expectedPersonStructType
val unsafeDs: TypedDataset[Person] = {
- val rdd = sc.parallelize(Seq(
- Row.fromTuple("Foo" -> 2),
- Row.fromTuple("Bar" -> 3)
- ))
+ val rdd = sc.parallelize(
+ Seq(
+ Row.fromTuple("Foo" -> 2),
+ Row.fromTuple("Bar" -> 3)
+ )
+ )
val df = session.createDataFrame(rdd, expectedPersonStructType)
TypedDataset.createUnsafe(df)(encoder)
}
- val expected = Seq(
- Person(new Name("Foo"), 2), Person(new Name("Bar"), 3))
+ val expected = Seq(Person(new Name("Foo"), 2), Person(new Name("Bar"), 3))
unsafeDs.collect.run() shouldBe expected
@@ -151,8 +172,10 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
val lorem = new Name("Lorem")
- safeDs.withColumnReplaced('name, functions.litValue(lorem)).
- collect.run() shouldBe expected.map(_.copy(name = lorem))
+ safeDs
+ .withColumnReplaced('name, functions.litValue(lorem))
+ .collect
+ .run() shouldBe expected.map(_.copy(name = lorem))
}
test("Case class with value class as optional field") {
@@ -160,7 +183,8 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
illTyped( // As `Person` is not a Value class
"""val _: RecordFieldEncoder[Option[Person]] =
- RecordFieldEncoder.optionValueClass""")
+ RecordFieldEncoder.optionValueClass"""
+ )
val fieldEncoder: RecordFieldEncoder[Option[Name]] =
RecordFieldEncoder.optionValueClass
@@ -168,33 +192,37 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
fieldEncoder.encoder.catalystRepr shouldBe StringType
fieldEncoder.encoder. // !StringType
- jvmRepr shouldBe ObjectType(classOf[Option[_]])
+ jvmRepr shouldBe ObjectType(classOf[Option[_]])
// Encode as a Person field
val encoder = TypedEncoder[User]
encoder.jvmRepr shouldBe ObjectType(classOf[User])
- val expectedPersonStructType = StructType(Seq(
- StructField("id", LongType, false),
- StructField("name", StringType, true)))
+ val expectedPersonStructType = StructType(
+ Seq(
+ StructField("id", LongType, false),
+ StructField("name", StringType, true)
+ )
+ )
encoder.catalystRepr shouldBe expectedPersonStructType
val ds1: TypedDataset[User] = {
- val rdd = sc.parallelize(Seq(
- Row(1L, null),
- Row(2L, "Foo")
- ))
+ val rdd = sc.parallelize(
+ Seq(
+ Row(1L, null),
+ Row(2L, "Foo")
+ )
+ )
val df = session.createDataFrame(rdd, expectedPersonStructType)
TypedDataset.createUnsafe(df)(encoder)
}
- ds1.collect.run() shouldBe Seq(
- User(1L, None),
- User(2L, Some(new Name("Foo"))))
+ ds1.collect
+ .run() shouldBe Seq(User(1L, None), User(2L, Some(new Name("Foo"))))
val ds2: TypedDataset[User] = {
val sqlContext = session.sqlContext
@@ -206,18 +234,18 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
"""{"id":5,"name":null}"""
).toDF
- val df2 = df1.withColumn(
- "jsonValue",
- F.from_json(df1.col("value"), expectedPersonStructType)).
- select("jsonValue.id", "jsonValue.name")
+ val df2 = df1
+ .withColumn(
+ "jsonValue",
+ F.from_json(df1.col("value"), expectedPersonStructType)
+ )
+ .select("jsonValue.id", "jsonValue.name")
TypedDataset.createUnsafe[User](df2)
}
- val expected = Seq(
- User(3L, None),
- User(4L, Some(new Name("Lorem"))),
- User(5L, None))
+ val expected =
+ Seq(User(3L, None), User(4L, Some(new Name("Lorem"))), User(5L, None))
ds2.collect.run() shouldBe expected
@@ -232,11 +260,19 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
encoder.jvmRepr shouldBe ObjectType(classOf[D])
- val expectedStructType = StructType(Seq(
- StructField("m", MapType(
- keyType = StringType,
- valueType = IntegerType,
- valueContainsNull = false), false)))
+ val expectedStructType = StructType(
+ Seq(
+ StructField(
+ "m",
+ MapType(
+ keyType = StringType,
+ valueType = IntegerType,
+ valueContainsNull = false
+ ),
+ false
+ )
+ )
+ )
encoder.catalystRepr shouldBe expectedStructType
@@ -246,18 +282,19 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
val ds1 = TypedDataset.createUnsafe[D] {
val df = Seq(
"""{"m":{"pizza":1,"sushi":2}}""",
- """{"m":{"red":3,"blue":4}}""",
+ """{"m":{"red":3,"blue":4}}"""
).toDF
df.withColumn(
"jsonValue",
- F.from_json(df.col("value"), expectedStructType)).
- select("jsonValue.*")
+ F.from_json(df.col("value"), expectedStructType)
+ ).select("jsonValue.*")
}
val expected = Seq(
D(m = Map("pizza" -> 1, "sushi" -> 2)),
- D(m = Map("red" -> 3, "blue" -> 4)))
+ D(m = Map("red" -> 3, "blue" -> 4))
+ )
ds1.collect.run() shouldBe expected
@@ -275,12 +312,20 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
encoder.jvmRepr shouldBe ObjectType(classOf[Student])
- val expectedStudentStructType = StructType(Seq(
- StructField("name", StringType, false),
- StructField("grades", MapType(
- keyType = StringType,
- valueType = DecimalType.SYSTEM_DEFAULT,
- valueContainsNull = false), false)))
+ val expectedStudentStructType = StructType(
+ Seq(
+ StructField("name", StringType, false),
+ StructField(
+ "grades",
+ MapType(
+ keyType = StringType,
+ valueType = DecimalType.SYSTEM_DEFAULT,
+ valueContainsNull = false
+ ),
+ false
+ )
+ )
+ )
encoder.catalystRepr shouldBe expectedStudentStructType
@@ -290,51 +335,65 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
val ds1 = TypedDataset.createUnsafe[Student] {
val df = Seq(
"""{"name":"Foo","grades":{"math":1,"physics":"23.4"}}""",
- """{"name":"Bar","grades":{"biology":18.5,"geography":4}}""",
+ """{"name":"Bar","grades":{"biology":18.5,"geography":4}}"""
).toDF
df.withColumn(
"jsonValue",
- F.from_json(df.col("value"), expectedStudentStructType)).
- select("jsonValue.*")
+ F.from_json(df.col("value"), expectedStudentStructType)
+ ).select("jsonValue.*")
}
val expected = Seq(
- Student(name = "Foo", grades = Map(
- new Subject("math") -> new Grade(BigDecimal(1)),
- new Subject("physics") -> new Grade(BigDecimal(23.4D)))),
- Student(name = "Bar", grades = Map(
- new Subject("biology") -> new Grade(BigDecimal(18.5)),
- new Subject("geography") -> new Grade(BigDecimal(4L)))))
+ Student(
+ name = "Foo",
+ grades = Map(
+ new Subject("math") -> new Grade(BigDecimal(1)),
+ new Subject("physics") -> new Grade(BigDecimal(23.4D))
+ )
+ ),
+ Student(
+ name = "Bar",
+ grades = Map(
+ new Subject("biology") -> new Grade(BigDecimal(18.5)),
+ new Subject("geography") -> new Grade(BigDecimal(4L))
+ )
+ )
+ )
ds1.collect.run() shouldBe expected
val grades = Map[Subject, Grade](
- new Subject("any") -> new Grade(BigDecimal(Long.MaxValue) + 1L))
+ new Subject("any") -> new Grade(BigDecimal(Long.MaxValue) + 1L)
+ )
val ds2 = ds1.withColumnReplaced('grades, functions.lit(grades))
- ds2.collect.run() shouldBe Seq(
- Student("Foo", grades), Student("Bar", grades))
+ ds2.collect
+ .run() shouldBe Seq(Student("Foo", grades), Student("Bar", grades))
}
test("Encode binary array") {
val encoder = TypedEncoder[Tuple2[String, Array[Byte]]]
- encoder.jvmRepr shouldBe ObjectType(
- classOf[Tuple2[String, Array[Byte]]])
+ encoder.jvmRepr shouldBe ObjectType(classOf[Tuple2[String, Array[Byte]]])
- val expectedStructType = StructType(Seq(
- StructField("_1", StringType, false),
- StructField("_2", BinaryType, false)))
+ val expectedStructType = StructType(
+ Seq(
+ StructField("_1", StringType, false),
+ StructField("_2", BinaryType, false)
+ )
+ )
encoder.catalystRepr shouldBe expectedStructType
val ds1: TypedDataset[(String, Array[Byte])] = {
- val rdd = sc.parallelize(Seq(
- Row.fromTuple("Foo" -> Array[Byte](3, 4)),
- Row.fromTuple("Bar" -> Array[Byte](5))
- ))
+ val rdd = sc.parallelize(
+ Seq(
+ Row.fromTuple("Foo" -> Array[Byte](3, 4)),
+ Row.fromTuple("Bar" -> Array[Byte](5))
+ )
+ )
val df = session.createDataFrame(rdd, expectedStructType)
TypedDataset.createUnsafe(df)(encoder)
@@ -342,28 +401,27 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
val expected = Seq("Foo" -> Seq[Byte](3, 4), "Bar" -> Seq[Byte](5))
- ds1.collect.run().map {
- case (_1, _2) => _1 -> _2.toSeq
- } shouldBe expected
+ ds1.collect.run().map { case (_1, _2) => _1 -> _2.toSeq } shouldBe expected
val subjects = "lorem".getBytes("UTF-8").toSeq
val ds2 = ds1.withColumnReplaced('_2, functions.lit(subjects.toArray))
- ds2.collect.run().map {
- case (_1, _2) => _1 -> _2.toSeq
- } shouldBe expected.map(_.copy(_2 = subjects))
+ ds2.collect.run().map { case (_1, _2) => _1 -> _2.toSeq } shouldBe expected
+ .map(_.copy(_2 = subjects))
}
test("Encode simple array") {
val encoder = TypedEncoder[Tuple2[String, Array[Int]]]
- encoder.jvmRepr shouldBe ObjectType(
- classOf[Tuple2[String, Array[Int]]])
+ encoder.jvmRepr shouldBe ObjectType(classOf[Tuple2[String, Array[Int]]])
- val expectedStructType = StructType(Seq(
- StructField("_1", StringType, false),
- StructField("_2", ArrayType(IntegerType, false), false)))
+ val expectedStructType = StructType(
+ Seq(
+ StructField("_1", StringType, false),
+ StructField("_2", ArrayType(IntegerType, false), false)
+ )
+ )
encoder.catalystRepr shouldBe expectedStructType
@@ -373,28 +431,25 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
val ds1 = TypedDataset.createUnsafe[(String, Array[Int])] {
val df = Seq(
"""{"_1":"Foo", "_2":[3, 4]}""",
- """{"_1":"Bar", "_2":[5]}""",
+ """{"_1":"Bar", "_2":[5]}"""
).toDF
df.withColumn(
"jsonValue",
- F.from_json(df.col("value"), expectedStructType)).
- select("jsonValue.*")
+ F.from_json(df.col("value"), expectedStructType)
+ ).select("jsonValue.*")
}
val expected = Seq("Foo" -> Seq(3, 4), "Bar" -> Seq(5))
- ds1.collect.run().map {
- case (_1, _2) => _1 -> _2.toSeq
- } shouldBe expected
+ ds1.collect.run().map { case (_1, _2) => _1 -> _2.toSeq } shouldBe expected
val subjects = Seq(6, 6, 7)
val ds2 = ds1.withColumnReplaced('_2, functions.lit(subjects.toArray))
- ds2.collect.run().map {
- case (_1, _2) => _1 -> _2.toSeq
- } shouldBe expected.map(_.copy(_2 = subjects))
+ ds2.collect.run().map { case (_1, _2) => _1 -> _2.toSeq } shouldBe expected
+ .map(_.copy(_2 = subjects))
}
test("Encode array of Value class") {
@@ -402,12 +457,14 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
val encoder = TypedEncoder[Tuple2[String, Array[Subject]]]
- encoder.jvmRepr shouldBe ObjectType(
- classOf[Tuple2[String, Array[Subject]]])
+ encoder.jvmRepr shouldBe ObjectType(classOf[Tuple2[String, Array[Subject]]])
- val expectedStructType = StructType(Seq(
- StructField("_1", StringType, false),
- StructField("_2", ArrayType(StringType, false), false)))
+ val expectedStructType = StructType(
+ Seq(
+ StructField("_1", StringType, false),
+ StructField("_2", ArrayType(StringType, false), false)
+ )
+ )
encoder.catalystRepr shouldBe expectedStructType
@@ -417,30 +474,28 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
val ds1 = TypedDataset.createUnsafe[(String, Array[Subject])] {
val df = Seq(
"""{"_1":"Foo", "_2":["math","physics"]}""",
- """{"_1":"Bar", "_2":["biology","geography"]}""",
+ """{"_1":"Bar", "_2":["biology","geography"]}"""
).toDF
df.withColumn(
"jsonValue",
- F.from_json(df.col("value"), expectedStructType)).
- select("jsonValue.*")
+ F.from_json(df.col("value"), expectedStructType)
+ ).select("jsonValue.*")
}
val expected = Seq(
"Foo" -> Seq(new Subject("math"), new Subject("physics")),
- "Bar" -> Seq(new Subject("biology"), new Subject("geography")))
+ "Bar" -> Seq(new Subject("biology"), new Subject("geography"))
+ )
- ds1.collect.run().map {
- case (_1, _2) => _1 -> _2.toSeq
- } shouldBe expected
+ ds1.collect.run().map { case (_1, _2) => _1 -> _2.toSeq } shouldBe expected
val subjects = Seq(new Subject("lorem"), new Subject("ipsum"))
val ds2 = ds1.withColumnReplaced('_2, functions.lit(subjects.toArray))
- ds2.collect.run().map {
- case (_1, _2) => _1 -> _2.toSeq
- } shouldBe expected.map(_.copy(_2 = subjects))
+ ds2.collect.run().map { case (_1, _2) => _1 -> _2.toSeq } shouldBe expected
+ .map(_.copy(_2 = subjects))
}
test("Encode case class with simple Seq") {
@@ -450,22 +505,41 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
encoder.jvmRepr shouldBe ObjectType(classOf[B])
- val expectedStructType = StructType(Seq(
- StructField("a", ArrayType(StructType(Seq(
- StructField("x", IntegerType, false))), false), false)))
+ val expectedStructType = StructType(
+ Seq(
+ StructField(
+ "a",
+ ArrayType(
+ StructType(Seq(StructField("x", IntegerType, false))),
+ false
+ ),
+ false
+ )
+ )
+ )
encoder.catalystRepr shouldBe expectedStructType
val ds1: TypedDataset[B] = {
- val rdd = sc.parallelize(Seq(
- Row.fromTuple(Tuple1(Seq(
- Row.fromTuple(Tuple1[Int](1)),
- Row.fromTuple(Tuple1[Int](3))
- ))),
- Row.fromTuple(Tuple1(Seq(
- Row.fromTuple(Tuple1[Int](2))
- )))
- ))
+ val rdd = sc.parallelize(
+ Seq(
+ Row.fromTuple(
+ Tuple1(
+ Seq(
+ Row.fromTuple(Tuple1[Int](1)),
+ Row.fromTuple(Tuple1[Int](3))
+ )
+ )
+ ),
+ Row.fromTuple(
+ Tuple1(
+ Seq(
+ Row.fromTuple(Tuple1[Int](2))
+ )
+ )
+ )
+ )
+ )
val df = session.createDataFrame(rdd, expectedStructType)
TypedDataset.createUnsafe(df)(encoder)
@@ -489,9 +563,12 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
encoder.jvmRepr shouldBe ObjectType(classOf[Tuple2[Int, Seq[Name]]])
- val expectedStructType = StructType(Seq(
- StructField("_1", IntegerType, false),
- StructField("_2", ArrayType(StringType, false), false)))
+ val expectedStructType = StructType(
+ Seq(
+ StructField("_1", IntegerType, false),
+ StructField("_2", ArrayType(StringType, false), false)
+ )
+ )
encoder.catalystRepr shouldBe expectedStructType
@@ -501,18 +578,19 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
val df = Seq(
"""{"_1":1, "_2":["foo", "bar"]}""",
- """{"_1":2, "_2":["lorem"]}""",
+ """{"_1":2, "_2":["lorem"]}"""
).toDF
df.withColumn(
"jsonValue",
- F.from_json(df.col("value"), expectedStructType)).
- select("jsonValue.*")
+ F.from_json(df.col("value"), expectedStructType)
+ ).select("jsonValue.*")
}
val expected = Seq(
1 -> Seq(new Name("foo"), new Name("bar")),
- 2 -> Seq(new Name("lorem")))
+ 2 -> Seq(new Name("lorem"))
+ )
ds1.collect.run() shouldBe expected
}
@@ -523,9 +601,15 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers {
case class UnitsOnly(a: Unit, b: Unit)
case class TupleWithUnits(
- u0: Unit, _1: Int, u1: Unit, u2: Unit, _2: String, u3: Unit)
+ u0: Unit,
+ _1: Int,
+ u1: Unit,
+ u2: Unit,
+ _2: String,
+ u3: Unit)
object TupleWithUnits {
+
def apply(_1: Int, _2: String): TupleWithUnits =
TupleWithUnits((), _1, (), (), _2, ())
}
diff --git a/dataset/src/test/scala/frameless/SchemaTests.scala b/dataset/src/test/scala/frameless/SchemaTests.scala
index 92fd33057..c93c92827 100644
--- a/dataset/src/test/scala/frameless/SchemaTests.scala
+++ b/dataset/src/test/scala/frameless/SchemaTests.scala
@@ -10,10 +10,13 @@ import org.scalatest.matchers.should.Matchers
class SchemaTests extends TypedDatasetSuite with Matchers {
def structToNonNullable(struct: StructType): StructType = {
- StructType(struct.fields.map( f => f.copy(nullable = false)))
+ StructType(struct.fields.map(f => f.copy(nullable = false)))
}
- def prop[A](dataset: TypedDataset[A], ignoreNullable: Boolean = false): Prop = {
+ def prop[A](
+ dataset: TypedDataset[A],
+ ignoreNullable: Boolean = false
+ ): Prop = {
val schema = dataset.dataset.schema
Prop.all(
@@ -24,7 +27,9 @@ class SchemaTests extends TypedDatasetSuite with Matchers {
if (!ignoreNullable)
TypedExpressionEncoder.targetStructType(dataset.encoder) ?= schema
else
- structToNonNullable(TypedExpressionEncoder.targetStructType(dataset.encoder)) ?= structToNonNullable(schema)
+ structToNonNullable(
+ TypedExpressionEncoder.targetStructType(dataset.encoder)
+ ) ?= structToNonNullable(schema)
)
}
diff --git a/dataset/src/test/scala/frameless/SelectTests.scala b/dataset/src/test/scala/frameless/SelectTests.scala
index 8043fc941..ced762726 100644
--- a/dataset/src/test/scala/frameless/SelectTests.scala
+++ b/dataset/src/test/scala/frameless/SelectTests.scala
@@ -7,12 +7,13 @@ import scala.reflect.ClassTag
class SelectTests extends TypedDatasetSuite {
test("select('a) FROM abcd") {
- def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])(
- implicit
- ea: TypedEncoder[A],
- ex4: TypedEncoder[X4[A, B, C, D]],
- ca: ClassTag[A]
- ): Prop = {
+ def prop[A, B, C, D](
+ data: Vector[X4[A, B, C, D]]
+ )(implicit
+ ea: TypedEncoder[A],
+ ex4: TypedEncoder[X4[A, B, C, D]],
+ ca: ClassTag[A]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
@@ -29,14 +30,15 @@ class SelectTests extends TypedDatasetSuite {
}
test("select('a, 'b) FROM abcd") {
- def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])(
- implicit
- ea: TypedEncoder[A],
- eb: TypedEncoder[B],
- eab: TypedEncoder[(A, B)],
- ex4: TypedEncoder[X4[A, B, C, D]],
- ca: ClassTag[A]
- ): Prop = {
+ def prop[A, B, C, D](
+ data: Vector[X4[A, B, C, D]]
+ )(implicit
+ ea: TypedEncoder[A],
+ eb: TypedEncoder[B],
+ eab: TypedEncoder[(A, B)],
+ ex4: TypedEncoder[X4[A, B, C, D]],
+ ca: ClassTag[A]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
@@ -53,15 +55,16 @@ class SelectTests extends TypedDatasetSuite {
}
test("select('a, 'b, 'c) FROM abcd") {
- def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])(
- implicit
- ea: TypedEncoder[A],
- eb: TypedEncoder[B],
- ec: TypedEncoder[C],
- eab: TypedEncoder[(A, B, C)],
- ex4: TypedEncoder[X4[A, B, C, D]],
- ca: ClassTag[A]
- ): Prop = {
+ def prop[A, B, C, D](
+ data: Vector[X4[A, B, C, D]]
+ )(implicit
+ ea: TypedEncoder[A],
+ eb: TypedEncoder[B],
+ ec: TypedEncoder[C],
+ eab: TypedEncoder[(A, B, C)],
+ ex4: TypedEncoder[X4[A, B, C, D]],
+ ca: ClassTag[A]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
@@ -79,15 +82,16 @@ class SelectTests extends TypedDatasetSuite {
}
test("select('a,'b,'c,'d) FROM abcd") {
- def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])(
- implicit
- ea: TypedEncoder[A],
- eb: TypedEncoder[B],
- ec: TypedEncoder[C],
- ed: TypedEncoder[D],
- ex4: TypedEncoder[X4[A, B, C, D]],
- ca: ClassTag[A]
- ): Prop = {
+ def prop[A, B, C, D](
+ data: Vector[X4[A, B, C, D]]
+ )(implicit
+ ea: TypedEncoder[A],
+ eb: TypedEncoder[B],
+ ec: TypedEncoder[C],
+ ed: TypedEncoder[D],
+ ex4: TypedEncoder[X4[A, B, C, D]],
+ ca: ClassTag[A]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val a1 = dataset.col[A]('a)
val a2 = dataset.col[B]('b)
@@ -106,15 +110,16 @@ class SelectTests extends TypedDatasetSuite {
}
test("select('a,'b,'c,'d,'a) FROM abcd") {
- def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])(
- implicit
- ea: TypedEncoder[A],
- eb: TypedEncoder[B],
- ec: TypedEncoder[C],
- ed: TypedEncoder[D],
- ex4: TypedEncoder[X4[A, B, C, D]],
- ca: ClassTag[A]
- ): Prop = {
+ def prop[A, B, C, D](
+ data: Vector[X4[A, B, C, D]]
+ )(implicit
+ ea: TypedEncoder[A],
+ eb: TypedEncoder[B],
+ ec: TypedEncoder[C],
+ ed: TypedEncoder[D],
+ ex4: TypedEncoder[X4[A, B, C, D]],
+ ca: ClassTag[A]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val a1 = dataset.col[A]('a)
val a2 = dataset.col[B]('b)
@@ -133,22 +138,24 @@ class SelectTests extends TypedDatasetSuite {
}
test("select('a,'b,'c,'d,'a, 'c) FROM abcd") {
- def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])(
- implicit
- ea: TypedEncoder[A],
- eb: TypedEncoder[B],
- ec: TypedEncoder[C],
- ed: TypedEncoder[D],
- ex4: TypedEncoder[X4[A, B, C, D]],
- ca: ClassTag[A]
- ): Prop = {
+ def prop[A, B, C, D](
+ data: Vector[X4[A, B, C, D]]
+ )(implicit
+ ea: TypedEncoder[A],
+ eb: TypedEncoder[B],
+ ec: TypedEncoder[C],
+ ed: TypedEncoder[D],
+ ex4: TypedEncoder[X4[A, B, C, D]],
+ ca: ClassTag[A]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val a1 = dataset.col[A]('a)
val a2 = dataset.col[B]('b)
val a3 = dataset.col[C]('c)
val a4 = dataset.col[D]('d)
- val dataset2 = dataset.select(a1, a2, a3, a4, a1, a3).collect().run().toVector
+ val dataset2 =
+ dataset.select(a1, a2, a3, a4, a1, a3).collect().run().toVector
val data2 = data.map { case X4(a, b, c, d) => (a, b, c, d, a, c) }
dataset2 ?= data2
@@ -160,22 +167,24 @@ class SelectTests extends TypedDatasetSuite {
}
test("select('a,'b,'c,'d,'a,'c,'b) FROM abcd") {
- def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])(
- implicit
- ea: TypedEncoder[A],
- eb: TypedEncoder[B],
- ec: TypedEncoder[C],
- ed: TypedEncoder[D],
- ex4: TypedEncoder[X4[A, B, C, D]],
- ca: ClassTag[A]
- ): Prop = {
+ def prop[A, B, C, D](
+ data: Vector[X4[A, B, C, D]]
+ )(implicit
+ ea: TypedEncoder[A],
+ eb: TypedEncoder[B],
+ ec: TypedEncoder[C],
+ ed: TypedEncoder[D],
+ ex4: TypedEncoder[X4[A, B, C, D]],
+ ca: ClassTag[A]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val a1 = dataset.col[A]('a)
val a2 = dataset.col[B]('b)
val a3 = dataset.col[C]('c)
val a4 = dataset.col[D]('d)
- val dataset2 = dataset.select(a1, a2, a3, a4, a1, a3, a2).collect().run().toVector
+ val dataset2 =
+ dataset.select(a1, a2, a3, a4, a1, a3, a2).collect().run().toVector
val data2 = data.map { case X4(a, b, c, d) => (a, b, c, d, a, c, b) }
dataset2 ?= data2
@@ -187,22 +196,24 @@ class SelectTests extends TypedDatasetSuite {
}
test("select('a,'b,'c,'d,'a,'c,'b, 'a) FROM abcd") {
- def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])(
- implicit
- ea: TypedEncoder[A],
- eb: TypedEncoder[B],
- ec: TypedEncoder[C],
- ed: TypedEncoder[D],
- ex4: TypedEncoder[X4[A, B, C, D]],
- ca: ClassTag[A]
- ): Prop = {
+ def prop[A, B, C, D](
+ data: Vector[X4[A, B, C, D]]
+ )(implicit
+ ea: TypedEncoder[A],
+ eb: TypedEncoder[B],
+ ec: TypedEncoder[C],
+ ed: TypedEncoder[D],
+ ex4: TypedEncoder[X4[A, B, C, D]],
+ ca: ClassTag[A]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val a1 = dataset.col[A]('a)
val a2 = dataset.col[B]('b)
val a3 = dataset.col[C]('c)
val a4 = dataset.col[D]('d)
- val dataset2 = dataset.select(a1, a2, a3, a4, a1, a3, a2, a1).collect().run().toVector
+ val dataset2 =
+ dataset.select(a1, a2, a3, a4, a1, a3, a2, a1).collect().run().toVector
val data2 = data.map { case X4(a, b, c, d) => (a, b, c, d, a, c, b, a) }
dataset2 ?= data2
@@ -214,23 +225,30 @@ class SelectTests extends TypedDatasetSuite {
}
test("select('a,'b,'c,'d,'a,'c,'b,'a,'c) FROM abcd") {
- def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])(
- implicit
- ea: TypedEncoder[A],
- eb: TypedEncoder[B],
- ec: TypedEncoder[C],
- ed: TypedEncoder[D],
- ex4: TypedEncoder[X4[A, B, C, D]],
- ca: ClassTag[A]
- ): Prop = {
+ def prop[A, B, C, D](
+ data: Vector[X4[A, B, C, D]]
+ )(implicit
+ ea: TypedEncoder[A],
+ eb: TypedEncoder[B],
+ ec: TypedEncoder[C],
+ ed: TypedEncoder[D],
+ ex4: TypedEncoder[X4[A, B, C, D]],
+ ca: ClassTag[A]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val a1 = dataset.col[A]('a)
val a2 = dataset.col[B]('b)
val a3 = dataset.col[C]('c)
val a4 = dataset.col[D]('d)
- val dataset2 = dataset.select(a1, a2, a3, a4, a1, a3, a2, a1, a3).collect().run().toVector
- val data2 = data.map { case X4(a, b, c, d) => (a, b, c, d, a, c, b, a, c) }
+ val dataset2 = dataset
+ .select(a1, a2, a3, a4, a1, a3, a2, a1, a3)
+ .collect()
+ .run()
+ .toVector
+ val data2 = data.map {
+ case X4(a, b, c, d) => (a, b, c, d, a, c, b, a, c)
+ }
dataset2 ?= data2
}
@@ -241,23 +259,30 @@ class SelectTests extends TypedDatasetSuite {
}
test("select('a,'b,'c,'d,'a,'c,'b,'a,'c, 'd) FROM abcd") {
- def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])(
- implicit
- ea: TypedEncoder[A],
- eb: TypedEncoder[B],
- ec: TypedEncoder[C],
- ed: TypedEncoder[D],
- ex4: TypedEncoder[X4[A, B, C, D]],
- ca: ClassTag[A]
- ): Prop = {
+ def prop[A, B, C, D](
+ data: Vector[X4[A, B, C, D]]
+ )(implicit
+ ea: TypedEncoder[A],
+ eb: TypedEncoder[B],
+ ec: TypedEncoder[C],
+ ed: TypedEncoder[D],
+ ex4: TypedEncoder[X4[A, B, C, D]],
+ ca: ClassTag[A]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val a1 = dataset.col[A]('a)
val a2 = dataset.col[B]('b)
val a3 = dataset.col[C]('c)
val a4 = dataset.col[D]('d)
- val dataset2 = dataset.select(a1, a2, a3, a4, a1, a3, a2, a1, a3, a4).collect().run().toVector
- val data2 = data.map { case X4(a, b, c, d) => (a, b, c, d, a, c, b, a, c, d) }
+ val dataset2 = dataset
+ .select(a1, a2, a3, a4, a1, a3, a2, a1, a3, a4)
+ .collect()
+ .run()
+ .toVector
+ val data2 = data.map {
+ case X4(a, b, c, d) => (a, b, c, d, a, c, b, a, c, d)
+ }
dataset2 ?= data2
}
@@ -268,12 +293,13 @@ class SelectTests extends TypedDatasetSuite {
}
test("select('a.b)") {
- def prop[A, B, C](data: Vector[X2[X2[A, B], C]])(
- implicit
- eabc: TypedEncoder[X2[X2[A, B], C]],
- eb: TypedEncoder[B],
- cb: ClassTag[B]
- ): Prop = {
+ def prop[A, B, C](
+ data: Vector[X2[X2[A, B], C]]
+ )(implicit
+ eabc: TypedEncoder[X2[X2[A, B], C]],
+ eb: TypedEncoder[B],
+ cb: ClassTag[B]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val AB = dataset.colMany('a, 'b)
@@ -287,13 +313,15 @@ class SelectTests extends TypedDatasetSuite {
}
test("select with column expression addition") {
- def prop[A](data: Vector[X1[A]], const: A)(
- implicit
- eabc: TypedEncoder[X1[A]],
- anum: CatalystNumeric[A],
- num: Numeric[A],
- eb: TypedEncoder[A]
- ): Prop = {
+ def prop[A](
+ data: Vector[X1[A]],
+ const: A
+ )(implicit
+ eabc: TypedEncoder[X1[A]],
+ anum: CatalystNumeric[A],
+ num: Numeric[A],
+ eb: TypedEncoder[A]
+ ): Prop = {
val ds = TypedDataset.create(data)
val dataset2 = ds.select(ds('a) + const).collect().run().toVector
@@ -309,13 +337,15 @@ class SelectTests extends TypedDatasetSuite {
}
test("select with column expression multiplication") {
- def prop[A](data: Vector[X1[A]], const: A)(
- implicit
- eabc: TypedEncoder[X1[A]],
- anum: CatalystNumeric[A],
- num: Numeric[A],
- eb: TypedEncoder[A]
- ): Prop = {
+ def prop[A](
+ data: Vector[X1[A]],
+ const: A
+ )(implicit
+ eabc: TypedEncoder[X1[A]],
+ anum: CatalystNumeric[A],
+ num: Numeric[A],
+ eb: TypedEncoder[A]
+ ): Prop = {
val ds = TypedDataset.create(data)
val dataset2 = ds.select(ds('a) * const).collect().run().toVector
@@ -331,13 +361,15 @@ class SelectTests extends TypedDatasetSuite {
}
test("select with column expression subtraction") {
- def prop[A](data: Vector[X1[A]], const: A)(
- implicit
- eabc: TypedEncoder[X1[A]],
- cnum: CatalystNumeric[A],
- num: Numeric[A],
- eb: TypedEncoder[A]
- ): Prop = {
+ def prop[A](
+ data: Vector[X1[A]],
+ const: A
+ )(implicit
+ eabc: TypedEncoder[X1[A]],
+ cnum: CatalystNumeric[A],
+ num: Numeric[A],
+ eb: TypedEncoder[A]
+ ): Prop = {
val ds = TypedDataset.create(data)
val dataset2 = ds.select(ds('a) - const).collect().run().toVector
@@ -352,17 +384,24 @@ class SelectTests extends TypedDatasetSuite {
}
test("select with column expression division") {
- def prop[A](data: Vector[X1[A]], const: A)(
- implicit
- eabc: TypedEncoder[X1[A]],
- anum: CatalystNumeric[A],
- frac: Fractional[A],
- eb: TypedEncoder[A]
- ): Prop = {
+ def prop[A](
+ data: Vector[X1[A]],
+ const: A
+ )(implicit
+ eabc: TypedEncoder[X1[A]],
+ anum: CatalystNumeric[A],
+ frac: Fractional[A],
+ eb: TypedEncoder[A]
+ ): Prop = {
val ds = TypedDataset.create(data)
if (const != 0) {
- val dataset2 = ds.select(ds('a) / const).collect().run().toVector.asInstanceOf[Vector[A]]
+ val dataset2 = ds
+ .select(ds('a) / const)
+ .collect()
+ .run()
+ .toVector
+ .asInstanceOf[Vector[A]]
val data2 = data.map { case X1(a) => frac.div(a, const) }
dataset2 ?= data2
} else 0 ?= 0
@@ -377,23 +416,36 @@ class SelectTests extends TypedDatasetSuite {
val t: TypedDataset[(Int, Int)] = e.select(e.col('i) * 2, e.col('i))
assert(t.select(t.col('_1)).collect().run().toList === List(2))
// Issue #54
- val fooT = t.select(t.col('_1)).deserialized.map(x => Tuple1.apply(x)).as[Foo]
+ val fooT =
+ t.select(t.col('_1)).deserialized.map(x => Tuple1.apply(x)).as[Foo]
assert(fooT.select(fooT('i)).collect().run().toList === List(2))
}
test("unary - on arithmetic") {
- val e = TypedDataset.create[(Int, String, Int)]((1, "a", 2) :: (2, "b", 4) :: (2, "b", 1) :: Nil)
+ val e = TypedDataset.create[(Int, String, Int)](
+ (1, "a", 2) :: (2, "b", 4) :: (2, "b", 1) :: Nil
+ )
assert(e.select(-e('_1)).collect().run().toVector === Vector(-1, -2, -2))
- assert(e.select(-(e('_1) + e('_3))).collect().run().toVector === Vector(-3, -6, -3))
+ assert(
+ e.select(-(e('_1) + e('_3))).collect().run().toVector === Vector(
+ -3,
+ -6,
+ -3
+ )
+ )
}
test("unary - on strings should not type check") {
- val e = TypedDataset.create[(Int, String, Long)]((1, "a", 2L) :: (2, "b", 4L) :: (2, "b", 1L) :: Nil)
+ val e = TypedDataset.create[(Int, String, Long)](
+ (1, "a", 2L) :: (2, "b", 4L) :: (2, "b", 1L) :: Nil
+ )
illTyped("""e.select( -e('_2) )""")
}
test("select with aggregation operations is not supported") {
- val e = TypedDataset.create[(Int, String, Long)]((1, "a", 2L) :: (2, "b", 4L) :: (2, "b", 1L) :: Nil)
+ val e = TypedDataset.create[(Int, String, Long)](
+ (1, "a", 2L) :: (2, "b", 4L) :: (2, "b", 1L) :: Nil
+ )
illTyped("""e.select(frameless.functions.aggregate.sum(e('_1)))""")
}
}
diff --git a/dataset/src/test/scala/frameless/SelfJoinTests.scala b/dataset/src/test/scala/frameless/SelfJoinTests.scala
index cede7be2a..847a23e0d 100644
--- a/dataset/src/test/scala/frameless/SelfJoinTests.scala
+++ b/dataset/src/test/scala/frameless/SelfJoinTests.scala
@@ -2,13 +2,18 @@ package frameless
import org.scalacheck.Prop
import org.scalacheck.Prop._
-import org.apache.spark.sql.{SparkSession, functions => sparkFunctions}
+import org.apache.spark.sql.{ SparkSession, functions => sparkFunctions }
class SelfJoinTests extends TypedDatasetSuite {
+
// Without crossJoin.enabled=true Spark doesn't like trivial join conditions:
// [error] Join condition is missing or trivial.
// [error] Use the CROSS JOIN syntax to allow cartesian products between these relations.
- def allowTrivialJoin[T](body: => T)(implicit session: SparkSession): T = {
+ def allowTrivialJoin[T](
+ body: => T
+ )(implicit
+ session: SparkSession
+ ): T = {
val crossJoin = "spark.sql.crossJoin.enabled"
val oldSetting = session.conf.get(crossJoin)
session.conf.set(crossJoin, "true")
@@ -17,7 +22,11 @@ class SelfJoinTests extends TypedDatasetSuite {
result
}
- def allowAmbiguousJoin[T](body: => T)(implicit session: SparkSession): T = {
+ def allowAmbiguousJoin[T](
+ body: => T
+ )(implicit
+ session: SparkSession
+ ): T = {
val crossJoin = "spark.sql.analyzer.failAmbiguousSelfJoin"
val oldSetting = session.conf.get(crossJoin)
session.conf.set(crossJoin, "false")
@@ -27,22 +36,26 @@ class SelfJoinTests extends TypedDatasetSuite {
}
test("self join with colLeft/colRight disambiguation") {
- def prop[
- A : TypedEncoder : Ordering,
- B : TypedEncoder : Ordering
- ](dx: List[X2[A, B]], d: X2[A, B]): Prop = allowAmbiguousJoin {
+ def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Ordering](
+ dx: List[X2[A, B]],
+ d: X2[A, B]
+ ): Prop = allowAmbiguousJoin {
val data = d :: dx
val ds = TypedDataset.create(data)
// This is the way to write unambiguous self-join in vanilla, see https://goo.gl/XnkSUD
val df1 = ds.dataset.as("df1")
val df2 = ds.dataset.as("df2")
- val vanilla = df1.join(df2,
- sparkFunctions.col("df1.a") === sparkFunctions.col("df2.a")).count()
+ val vanilla = df1
+ .join(df2, sparkFunctions.col("df1.a") === sparkFunctions.col("df2.a"))
+ .count()
- val typed = ds.joinInner(ds)(
- ds.colLeft('a) === ds.colRight('a)
- ).count().run()
+ val typed = ds
+ .joinInner(ds)(
+ ds.colLeft('a) === ds.colRight('a)
+ )
+ .count()
+ .run()
vanilla ?= typed
}
@@ -51,47 +64,60 @@ class SelfJoinTests extends TypedDatasetSuite {
}
test("trivial self join") {
- def prop[
- A : TypedEncoder : Ordering,
- B : TypedEncoder : Ordering
- ](dx: List[X2[A, B]], d: X2[A, B]): Prop =
- allowTrivialJoin { allowAmbiguousJoin {
-
- val data = d :: dx
- val ds = TypedDataset.create(data)
- val untyped = ds.dataset
- // Interestingly, even with aliasing it seems that it's impossible to
- // obtain a trivial join condition of shape df1.a == df1.a, Spark we
- // always interpret that as df1.a == df2.a. For the purpose of this
- // test we fall-back to lit(true) instead.
- // val trivial = sparkFunctions.col("df1.a") === sparkFunctions.col("df1.a")
- val trivial = sparkFunctions.lit(true)
- val vanilla = untyped.as("df1").join(untyped.as("df2"), trivial).count()
-
- val typed = ds.joinInner(ds)(ds.colLeft('a) === ds.colLeft('a)).count().run
- vanilla ?= typed
- } }
+ def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Ordering](
+ dx: List[X2[A, B]],
+ d: X2[A, B]
+ ): Prop =
+ allowTrivialJoin {
+ allowAmbiguousJoin {
+
+ val data = d :: dx
+ val ds = TypedDataset.create(data)
+ val untyped = ds.dataset
+ // Interestingly, even with aliasing it seems that it's impossible to
+          // obtain a trivial join condition of shape df1.a == df1.a; Spark
+          // will always interpret that as df1.a == df2.a. For the purpose of this
+          // test we fall back to lit(true) instead.
+ // val trivial = sparkFunctions.col("df1.a") === sparkFunctions.col("df1.a")
+ val trivial = sparkFunctions.lit(true)
+ val vanilla =
+ untyped.as("df1").join(untyped.as("df2"), trivial).count()
+
+ val typed =
+ ds.joinInner(ds)(ds.colLeft('a) === ds.colLeft('a)).count().run
+ vanilla ?= typed
+ }
+ }
check(prop[Int, Int] _)
}
test("self join with unambiguous expression") {
def prop[
- A : TypedEncoder : CatalystNumeric : Ordering,
- B : TypedEncoder : Ordering
- ](data: List[X3[A, A, B]]): Prop = allowAmbiguousJoin {
+ A: TypedEncoder: CatalystNumeric: Ordering,
+ B: TypedEncoder: Ordering
+ ](data: List[X3[A, A, B]]
+ ): Prop = allowAmbiguousJoin {
val ds = TypedDataset.create(data)
val df1 = ds.dataset.alias("df1")
val df2 = ds.dataset.alias("df2")
- val vanilla = df1.join(df2,
- (sparkFunctions.col("df1.a") + sparkFunctions.col("df1.b")) ===
- (sparkFunctions.col("df2.a") + sparkFunctions.col("df2.b"))).count()
-
- val typed = ds.joinInner(ds)(
- (ds.colLeft('a) + ds.colLeft('b)) === (ds.colRight('a) + ds.colRight('b))
- ).count().run()
+ val vanilla = df1
+ .join(
+ df2,
+ (sparkFunctions.col("df1.a") + sparkFunctions.col("df1.b")) ===
+ (sparkFunctions.col("df2.a") + sparkFunctions.col("df2.b"))
+ )
+ .count()
+
+ val typed = ds
+ .joinInner(ds)(
+ (ds.colLeft('a) + ds.colLeft('b)) === (ds.colRight('a) + ds
+ .colRight('b))
+ )
+ .count()
+ .run()
vanilla ?= typed
}
@@ -99,41 +125,57 @@ class SelfJoinTests extends TypedDatasetSuite {
check(prop[Int, Int] _)
}
- test("Do you want ambiguous self join? This is how you get ambiguous self join.") {
+ test(
+ "Do you want ambiguous self join? This is how you get ambiguous self join."
+ ) {
def prop[
- A : TypedEncoder : CatalystNumeric : Ordering,
- B : TypedEncoder : Ordering
- ](data: List[X3[A, A, B]]): Prop =
- allowTrivialJoin { allowAmbiguousJoin {
- val ds = TypedDataset.create(data)
-
- // The point I'm making here is that it "behaves just like Spark". I
- // don't know (or really care about how) how Spark disambiguates that
- // internally...
- val vanilla = ds.dataset.join(ds.dataset,
- (ds.dataset("a") + ds.dataset("b")) ===
- (ds.dataset("a") + ds.dataset("b"))).count()
-
- val typed = ds.joinInner(ds)(
- (ds.col('a) + ds.col('b)) === (ds.col('a) + ds.col('b))
- ).count().run()
-
- vanilla ?= typed
- } }
-
- check(prop[Int, Int] _)
- }
+ A: TypedEncoder: CatalystNumeric: Ordering,
+ B: TypedEncoder: Ordering
+ ](data: List[X3[A, A, B]]
+ ): Prop =
+ allowTrivialJoin {
+ allowAmbiguousJoin {
+ val ds = TypedDataset.create(data)
+
+ // The point I'm making here is that it "behaves just like Spark". I
+          // don't know (or really care) how Spark disambiguates that
+ // internally...
+ val vanilla = ds.dataset
+ .join(
+ ds.dataset,
+ (ds.dataset("a") + ds.dataset("b")) ===
+ (ds.dataset("a") + ds.dataset("b"))
+ )
+ .count()
+
+ val typed = ds
+ .joinInner(ds)(
+ (ds.col('a) + ds.col('b)) === (ds.col('a) + ds.col('b))
+ )
+ .count()
+ .run()
+
+ vanilla ?= typed
+ }
+ }
+
+ check(prop[Int, Int] _)
+ }
test("colLeft and colRight are equivalent to col outside of joins") {
- def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])(
- implicit
- ea: TypedEncoder[A],
- ex4: TypedEncoder[X4[A, B, C, D]]
- ): Prop = {
+ def prop[A, B, C, D](
+ data: Vector[X4[A, B, C, D]]
+ )(implicit
+ ea: TypedEncoder[A],
+ ex4: TypedEncoder[X4[A, B, C, D]]
+ ): Prop = {
val dataset = TypedDataset.create(data)
- val selectedCol = dataset.select(dataset.col [A]('a)).collect().run().toVector
- val selectedColLeft = dataset.select(dataset.colLeft [A]('a)).collect().run().toVector
- val selectedColRight = dataset.select(dataset.colRight[A]('a)).collect().run().toVector
+ val selectedCol =
+ dataset.select(dataset.col[A]('a)).collect().run().toVector
+ val selectedColLeft =
+ dataset.select(dataset.colLeft[A]('a)).collect().run().toVector
+ val selectedColRight =
+ dataset.select(dataset.colRight[A]('a)).collect().run().toVector
(selectedCol ?= selectedColLeft) && (selectedCol ?= selectedColRight)
}
@@ -145,16 +187,26 @@ class SelfJoinTests extends TypedDatasetSuite {
}
test("colLeft and colRight are equivalent to col outside of joins - via files (codegen)") {
- def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])(
- implicit
- ea: TypedEncoder[A],
- ex4: TypedEncoder[X4[A, B, C, D]]
- ): Prop = {
- TypedDataset.create(data).write.mode("overwrite").parquet("./target/testData")
- val dataset = TypedDataset.createUnsafe[X4[A, B, C, D]](session.read.parquet("./target/testData"))
- val selectedCol = dataset.select(dataset.col [A]('a)).collect().run().toVector
- val selectedColLeft = dataset.select(dataset.colLeft [A]('a)).collect().run().toVector
- val selectedColRight = dataset.select(dataset.colRight[A]('a)).collect().run().toVector
+ def prop[A, B, C, D](
+ data: Vector[X4[A, B, C, D]]
+ )(implicit
+ ea: TypedEncoder[A],
+ ex4: TypedEncoder[X4[A, B, C, D]]
+ ): Prop = {
+ TypedDataset
+ .create(data)
+ .write
+ .mode("overwrite")
+ .parquet("./target/testData")
+ val dataset = TypedDataset.createUnsafe[X4[A, B, C, D]](
+ session.read.parquet("./target/testData")
+ )
+ val selectedCol =
+ dataset.select(dataset.col[A]('a)).collect().run().toVector
+ val selectedColLeft =
+ dataset.select(dataset.colLeft[A]('a)).collect().run().toVector
+ val selectedColRight =
+ dataset.select(dataset.colRight[A]('a)).collect().run().toVector
(selectedCol ?= selectedColLeft) && (selectedCol ?= selectedColRight)
}
diff --git a/dataset/src/test/scala/frameless/TypedDatasetSuite.scala b/dataset/src/test/scala/frameless/TypedDatasetSuite.scala
index 8a4697835..0845b39f8 100644
--- a/dataset/src/test/scala/frameless/TypedDatasetSuite.scala
+++ b/dataset/src/test/scala/frameless/TypedDatasetSuite.scala
@@ -2,28 +2,35 @@ package frameless
import com.globalmentor.apache.hadoop.fs.BareLocalFileSystem
import org.apache.hadoop.fs.local.StreamingFS
-import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.sql.{SQLContext, SparkSession}
+import org.apache.spark.{ SparkConf, SparkContext }
+import org.apache.spark.sql.{ SQLContext, SparkSession }
import org.scalactic.anyvals.PosZInt
import org.scalatest.BeforeAndAfterAll
import org.scalatestplus.scalacheck.Checkers
import org.scalacheck.Prop
import org.scalacheck.Prop._
-import scala.util.{Properties, Try}
+import scala.util.{ Properties, Try }
import org.scalatest.funsuite.AnyFunSuite
trait SparkTesting { self: BeforeAndAfterAll =>
- val appID: String = new java.util.Date().toString + math.floor(math.random * 10E4).toLong.toString
+ val appID: String = new java.util.Date().toString + math
+ .floor(math.random * 10e4)
+ .toLong
+ .toString
/**
* Allows bare naked to be used instead of winutils for testing / dev
*/
def registerFS(sparkConf: SparkConf): SparkConf = {
if (System.getProperty("os.name").startsWith("Windows"))
- sparkConf.set("spark.hadoop.fs.file.impl", classOf[BareLocalFileSystem].getName).
- set("spark.hadoop.fs.AbstractFileSystem.file.impl", classOf[StreamingFS].getName)
+ sparkConf
+ .set("spark.hadoop.fs.file.impl", classOf[BareLocalFileSystem].getName)
+ .set(
+ "spark.hadoop.fs.AbstractFileSystem.file.impl",
+ classOf[StreamingFS].getName
+ )
else
sparkConf
}
@@ -40,9 +47,9 @@ trait SparkTesting { self: BeforeAndAfterAll =>
implicit def sc: SparkContext = session.sparkContext
implicit def sqlContext: SQLContext = session.sqlContext
- def registerOptimizations(sqlContext: SQLContext): Unit = { }
+ def registerOptimizations(sqlContext: SQLContext): Unit = {}
- def addSparkConfigProperties(config: SparkConf): Unit = { }
+ def addSparkConfigProperties(config: SparkConf): Unit = {}
override def beforeAll(): Unit = {
assert(s == null)
@@ -59,11 +66,16 @@ trait SparkTesting { self: BeforeAndAfterAll =>
}
}
+class TypedDatasetSuite
+ extends AnyFunSuite
+ with Checkers
+ with BeforeAndAfterAll
+ with SparkTesting {
-class TypedDatasetSuite extends AnyFunSuite with Checkers with BeforeAndAfterAll with SparkTesting {
// Limit size of generated collections and number of checks to avoid OutOfMemoryError
implicit override val generatorDrivenConfig: PropertyCheckConfiguration = {
- def getPosZInt(name: String, default: PosZInt) = Properties.envOrNone(s"FRAMELESS_GEN_${name}")
+ def getPosZInt(name: String, default: PosZInt) = Properties
+ .envOrNone(s"FRAMELESS_GEN_${name}")
.flatMap(s => Try(s.toInt).toOption)
.flatMap(PosZInt.from)
.getOrElse(default)
@@ -75,17 +87,24 @@ class TypedDatasetSuite extends AnyFunSuite with Checkers with BeforeAndAfterAll
implicit val sparkDelay: SparkDelay[Job] = Job.framelessSparkDelayForJob
- def approximatelyEqual[A](a: A, b: A)(implicit numeric: Numeric[A]): Prop = {
+ def approximatelyEqual[A](
+ a: A,
+ b: A
+ )(implicit
+ numeric: Numeric[A]
+ ): Prop = {
val da = numeric.toDouble(a)
val db = numeric.toDouble(b)
- val epsilon = 1E-6
+ val epsilon = 1e-6
// Spark has a weird behaviour concerning expressions that should return Inf
// Most of the time they return NaN instead, for instance stddev of Seq(-7.827553978923477E227, -5.009124275715786E153)
- if((da.isNaN || da.isInfinity) && (db.isNaN || db.isInfinity)) proved
+ if ((da.isNaN || da.isInfinity) && (db.isNaN || db.isInfinity)) proved
else if (
(da - db).abs < epsilon ||
- (da - db).abs < da.abs / 100)
- proved
- else falsified :| s"Expected $a but got $b, which is more than 1% off and greater than epsilon = $epsilon."
+ (da - db).abs < da.abs / 100
+ )
+ proved
+ else
+ falsified :| s"Expected $a but got $b, which is more than 1% off and greater than epsilon = $epsilon."
}
}
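
The approximatelyEqual helper above accepts either an absolute error below epsilon or a relative error below 1% of the expected value, and treats NaN/Infinity pairs as equal. A standalone sketch of the same rule, with ad hoc example values that are not part of this patch:

object ApproxEqualSketch extends App {

  // Same rule as the suite: treat NaN/Infinity on both sides as equal,
  // otherwise accept an absolute error below epsilon or a relative error below 1%.
  def approxEqual(a: Double, b: Double, epsilon: Double = 1e-6): Boolean =
    if ((a.isNaN || a.isInfinity) && (b.isNaN || b.isInfinity)) true
    else (a - b).abs < epsilon || (a - b).abs < a.abs / 100

  assert(approxEqual(100.0, 100.0000001)) // within epsilon
  assert(approxEqual(1e9, 1e9 + 1000))    // within 1% relative tolerance
  assert(!approxEqual(1.0, 2.0))          // clearly different
}
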
diff --git a/dataset/src/test/scala/frameless/UdtEncodedClass.scala b/dataset/src/test/scala/frameless/UdtEncodedClass.scala
index 4e5c2c6d9..917e4c1f1 100644
--- a/dataset/src/test/scala/frameless/UdtEncodedClass.scala
+++ b/dataset/src/test/scala/frameless/UdtEncodedClass.scala
@@ -1,14 +1,19 @@
package frameless
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.expressions.{
+ GenericInternalRow,
+ UnsafeArrayData
+}
import org.apache.spark.sql.types._
import org.apache.spark.sql.FramelessInternals.UserDefinedType
@SQLUserDefinedType(udt = classOf[UdtEncodedClassUdt])
class UdtEncodedClass(val a: Int, val b: Array[Double]) {
+
override def equals(other: Any): Boolean = other match {
- case that: UdtEncodedClass => a == that.a && java.util.Arrays.equals(b, that.b)
+ case that: UdtEncodedClass =>
+ a == that.a && java.util.Arrays.equals(b, that.b)
case _ => false
}
@@ -25,11 +30,18 @@ object UdtEncodedClass {
}
class UdtEncodedClassUdt extends UserDefinedType[UdtEncodedClass] {
+
def sqlType: DataType = {
- StructType(Seq(
- StructField("a", IntegerType, nullable = false),
- StructField("b", ArrayType(DoubleType, containsNull = false), nullable = false)
- ))
+ StructType(
+ Seq(
+ StructField("a", IntegerType, nullable = false),
+ StructField(
+ "b",
+ ArrayType(DoubleType, containsNull = false),
+ nullable = false
+ )
+ )
+ )
}
def serialize(obj: UdtEncodedClass): InternalRow = {
@@ -40,7 +52,8 @@ class UdtEncodedClassUdt extends UserDefinedType[UdtEncodedClass] {
}
def deserialize(datum: Any): UdtEncodedClass = datum match {
- case row: InternalRow => new UdtEncodedClass(row.getInt(0), row.getArray(1).toDoubleArray())
+ case row: InternalRow =>
+ new UdtEncodedClass(row.getInt(0), row.getArray(1).toDoubleArray())
}
def userClass: Class[UdtEncodedClass] = classOf[UdtEncodedClass]
diff --git a/dataset/src/test/scala/frameless/WithColumnTest.scala b/dataset/src/test/scala/frameless/WithColumnTest.scala
index c41c4e726..4d62a4f05 100644
--- a/dataset/src/test/scala/frameless/WithColumnTest.scala
+++ b/dataset/src/test/scala/frameless/WithColumnTest.scala
@@ -8,28 +8,32 @@ class WithColumnTest extends TypedDatasetSuite {
import WithColumnTest._
test("fail to compile on missing value") {
- val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil)
+ val f: TypedDataset[X] =
+ TypedDataset.create(X(1, 1) :: X(1, 1) :: X(1, 10) :: Nil)
illTyped {
"""val fNew: TypedDataset[XMissing] = f.withColumn[XMissing](f('j) === 10)"""
}
}
test("fail to compile on different column name") {
- val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil)
+ val f: TypedDataset[X] =
+ TypedDataset.create(X(1, 1) :: X(1, 1) :: X(1, 10) :: Nil)
illTyped {
"""val fNew: TypedDataset[XDifferentColumnName] = f.withColumn[XDifferentColumnName](f('j) === 10)"""
}
}
test("fail to compile on added column name") {
- val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil)
+ val f: TypedDataset[X] =
+ TypedDataset.create(X(1, 1) :: X(1, 1) :: X(1, 10) :: Nil)
illTyped {
"""val fNew: TypedDataset[XAdded] = f.withColumn[XAdded](f('j) === 10)"""
}
}
test("fail to compile on wrong typed column") {
- val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil)
+ val f: TypedDataset[X] =
+ TypedDataset.create(X(1, 1) :: X(1, 1) :: X(1, 10) :: Nil)
illTyped {
"""val fNew: TypedDataset[XWrongType] = f.withColumn[XWrongType](f('j) === 10)"""
}
@@ -54,13 +58,10 @@ class WithColumnTest extends TypedDatasetSuite {
}
test("update in place") {
- def prop[A : TypedEncoder](startValue: A, replaceValue: A): Prop = {
+ def prop[A: TypedEncoder](startValue: A, replaceValue: A): Prop = {
val d = TypedDataset.create(X2(startValue, replaceValue) :: Nil)
- val X2(a, b) = d.withColumnReplaced('a, d('b))
- .collect()
- .run()
- .head
+ val X2(a, b) = d.withColumnReplaced('a, d('b)).collect().run().head
a ?= b
}
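
The compile-failure tests above rely on shapeless's illTyped macro to assert that the quoted code does not type-check. A minimal, self-contained sketch of that mechanism, where the Point class is an ad hoc example and not taken from this patch:

import shapeless.test.illTyped

object IllTypedSketch {
  final case class Point(x: Int, y: Int)

  // The quoted snippet must NOT type-check; if it ever compiles, this line
  // itself becomes a compile error, which is how the tests above guard
  // withColumn's schema checks.
  illTyped { """val p: Point = Point(1, "two")""" }
}
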
diff --git a/dataset/src/test/scala/frameless/XN.scala b/dataset/src/test/scala/frameless/XN.scala
index c23d4b45d..a92f7d9d1 100644
--- a/dataset/src/test/scala/frameless/XN.scala
+++ b/dataset/src/test/scala/frameless/XN.scala
@@ -1,14 +1,18 @@
package frameless
-import org.scalacheck.{Arbitrary, Cogen}
+import org.scalacheck.{ Arbitrary, Cogen }
case class X1[A](a: A)
object X1 {
+
implicit def arbitrary[A: Arbitrary]: Arbitrary[X1[A]] =
Arbitrary(implicitly[Arbitrary[A]].arbitrary.map(X1(_)))
- implicit def cogen[A](implicit A: Cogen[A]): Cogen[X1[A]] =
+ implicit def cogen[A](
+ implicit
+ A: Cogen[A]
+ ): Cogen[X1[A]] =
A.contramap(_.a)
implicit def ordering[A: Ordering]: Ordering[X1[A]] = Ordering[A].on(_.a)
@@ -17,89 +21,225 @@ object X1 {
case class X2[A, B](a: A, b: B)
object X2 {
- implicit def arbitrary[A: Arbitrary, B: Arbitrary]: Arbitrary[X2[A, B]] =
- Arbitrary(Arbitrary.arbTuple2[A, B].arbitrary.map((X2.apply[A, B] _).tupled))
- implicit def cogen[A, B](implicit A: Cogen[A], B: Cogen[B]): Cogen[X2[A, B]] =
+ implicit def arbitrary[A: Arbitrary, B: Arbitrary]: Arbitrary[X2[A, B]] =
+ Arbitrary(
+ Arbitrary.arbTuple2[A, B].arbitrary.map((X2.apply[A, B] _).tupled)
+ )
+
+ implicit def cogen[A, B](
+ implicit
+ A: Cogen[A],
+ B: Cogen[B]
+ ): Cogen[X2[A, B]] =
Cogen.tuple2(A, B).contramap(x => (x.a, x.b))
- implicit def ordering[A: Ordering, B: Ordering]: Ordering[X2[A, B]] = Ordering.Tuple2[A, B].on(x => (x.a, x.b))
+ implicit def ordering[A: Ordering, B: Ordering]: Ordering[X2[A, B]] =
+ Ordering.Tuple2[A, B].on(x => (x.a, x.b))
}
case class X3[A, B, C](a: A, b: B, c: C)
object X3 {
- implicit def arbitrary[A: Arbitrary, B: Arbitrary, C: Arbitrary]: Arbitrary[X3[A, B, C]] =
- Arbitrary(Arbitrary.arbTuple3[A, B, C].arbitrary.map((X3.apply[A, B, C] _).tupled))
- implicit def cogen[A, B, C](implicit A: Cogen[A], B: Cogen[B], C: Cogen[C]): Cogen[X3[A, B, C]] =
+ implicit def arbitrary[
+ A: Arbitrary,
+ B: Arbitrary,
+ C: Arbitrary
+ ]: Arbitrary[X3[A, B, C]] =
+ Arbitrary(
+ Arbitrary.arbTuple3[A, B, C].arbitrary.map((X3.apply[A, B, C] _).tupled)
+ )
+
+ implicit def cogen[A, B, C](
+ implicit
+ A: Cogen[A],
+ B: Cogen[B],
+ C: Cogen[C]
+ ): Cogen[X3[A, B, C]] =
Cogen.tuple3(A, B, C).contramap(x => (x.a, x.b, x.c))
- implicit def ordering[A: Ordering, B: Ordering, C: Ordering]: Ordering[X3[A, B, C]] =
+ implicit def ordering[
+ A: Ordering,
+ B: Ordering,
+ C: Ordering
+ ]: Ordering[X3[A, B, C]] =
Ordering.Tuple3[A, B, C].on(x => (x.a, x.b, x.c))
}
case class X3U[A, B, C](a: A, b: B, u: Unit, c: C)
object X3U {
- implicit def arbitrary[A: Arbitrary, B: Arbitrary, C: Arbitrary]: Arbitrary[X3U[A, B, C]] =
- Arbitrary(Arbitrary.arbTuple3[A, B, C].arbitrary.map(x => X3U[A, B, C](x._1, x._2, (), x._3)))
- implicit def cogen[A, B, C](implicit A: Cogen[A], B: Cogen[B], C: Cogen[C]): Cogen[X3U[A, B, C]] =
+ implicit def arbitrary[
+ A: Arbitrary,
+ B: Arbitrary,
+ C: Arbitrary
+ ]: Arbitrary[X3U[A, B, C]] =
+ Arbitrary(
+ Arbitrary
+ .arbTuple3[A, B, C]
+ .arbitrary
+ .map(x => X3U[A, B, C](x._1, x._2, (), x._3))
+ )
+
+ implicit def cogen[A, B, C](
+ implicit
+ A: Cogen[A],
+ B: Cogen[B],
+ C: Cogen[C]
+ ): Cogen[X3U[A, B, C]] =
Cogen.tuple3(A, B, C).contramap(x => (x.a, x.b, x.c))
- implicit def ordering[A: Ordering, B: Ordering, C: Ordering]: Ordering[X3U[A, B, C]] =
+ implicit def ordering[
+ A: Ordering,
+ B: Ordering,
+ C: Ordering
+ ]: Ordering[X3U[A, B, C]] =
Ordering.Tuple3[A, B, C].on(x => (x.a, x.b, x.c))
}
case class X3KV[A, B, C](key: A, value: B, c: C)
object X3KV {
- implicit def arbitrary[A: Arbitrary, B: Arbitrary, C: Arbitrary]: Arbitrary[X3KV[A, B, C]] =
- Arbitrary(Arbitrary.arbTuple3[A, B, C].arbitrary.map((X3KV.apply[A, B, C] _).tupled))
- implicit def cogen[A, B, C](implicit A: Cogen[A], B: Cogen[B], C: Cogen[C]): Cogen[X3KV[A, B, C]] =
+ implicit def arbitrary[
+ A: Arbitrary,
+ B: Arbitrary,
+ C: Arbitrary
+ ]: Arbitrary[X3KV[A, B, C]] =
+ Arbitrary(
+ Arbitrary.arbTuple3[A, B, C].arbitrary.map((X3KV.apply[A, B, C] _).tupled)
+ )
+
+ implicit def cogen[A, B, C](
+ implicit
+ A: Cogen[A],
+ B: Cogen[B],
+ C: Cogen[C]
+ ): Cogen[X3KV[A, B, C]] =
Cogen.tuple3(A, B, C).contramap(x => (x.key, x.value, x.c))
- implicit def ordering[A: Ordering, B: Ordering, C: Ordering]: Ordering[X3KV[A, B, C]] =
+ implicit def ordering[
+ A: Ordering,
+ B: Ordering,
+ C: Ordering
+ ]: Ordering[X3KV[A, B, C]] =
Ordering.Tuple3[A, B, C].on(x => (x.key, x.value, x.c))
}
case class X4[A, B, C, D](a: A, b: B, c: C, d: D)
object X4 {
- implicit def arbitrary[A: Arbitrary, B: Arbitrary, C: Arbitrary, D: Arbitrary]: Arbitrary[X4[A, B, C, D]] =
- Arbitrary(Arbitrary.arbTuple4[A, B, C, D].arbitrary.map((X4.apply[A, B, C, D] _).tupled))
- implicit def cogen[A, B, C, D](implicit A: Cogen[A], B: Cogen[B], C: Cogen[C], D: Cogen[D]): Cogen[X4[A, B, C, D]] =
+ implicit def arbitrary[
+ A: Arbitrary,
+ B: Arbitrary,
+ C: Arbitrary,
+ D: Arbitrary
+ ]: Arbitrary[X4[A, B, C, D]] =
+ Arbitrary(
+ Arbitrary
+ .arbTuple4[A, B, C, D]
+ .arbitrary
+ .map((X4.apply[A, B, C, D] _).tupled)
+ )
+
+ implicit def cogen[A, B, C, D](
+ implicit
+ A: Cogen[A],
+ B: Cogen[B],
+ C: Cogen[C],
+ D: Cogen[D]
+ ): Cogen[X4[A, B, C, D]] =
Cogen.tuple4(A, B, C, D).contramap(x => (x.a, x.b, x.c, x.d))
- implicit def ordering[A: Ordering, B: Ordering, C: Ordering, D: Ordering]: Ordering[X4[A, B, C, D]] =
+ implicit def ordering[
+ A: Ordering,
+ B: Ordering,
+ C: Ordering,
+ D: Ordering
+ ]: Ordering[X4[A, B, C, D]] =
Ordering.Tuple4[A, B, C, D].on(x => (x.a, x.b, x.c, x.d))
}
case class X5[A, B, C, D, E](a: A, b: B, c: C, d: D, e: E)
object X5 {
- implicit def arbitrary[A: Arbitrary, B: Arbitrary, C: Arbitrary, D: Arbitrary, E: Arbitrary]: Arbitrary[X5[A, B, C, D, E]] =
- Arbitrary(Arbitrary.arbTuple5[A, B, C, D, E].arbitrary.map((X5.apply[A, B, C, D, E] _).tupled))
- implicit def cogen[A, B, C, D, E](implicit A: Cogen[A], B: Cogen[B], C: Cogen[C], D: Cogen[D], E: Cogen[E]): Cogen[X5[A, B, C, D, E]] =
+ implicit def arbitrary[
+ A: Arbitrary,
+ B: Arbitrary,
+ C: Arbitrary,
+ D: Arbitrary,
+ E: Arbitrary
+ ]: Arbitrary[X5[A, B, C, D, E]] =
+ Arbitrary(
+ Arbitrary
+ .arbTuple5[A, B, C, D, E]
+ .arbitrary
+ .map((X5.apply[A, B, C, D, E] _).tupled)
+ )
+
+ implicit def cogen[A, B, C, D, E](
+ implicit
+ A: Cogen[A],
+ B: Cogen[B],
+ C: Cogen[C],
+ D: Cogen[D],
+ E: Cogen[E]
+ ): Cogen[X5[A, B, C, D, E]] =
Cogen.tuple5(A, B, C, D, E).contramap(x => (x.a, x.b, x.c, x.d, x.e))
- implicit def ordering[A: Ordering, B: Ordering, C: Ordering, D: Ordering, E: Ordering]: Ordering[X5[A, B, C, D, E]] =
+ implicit def ordering[
+ A: Ordering,
+ B: Ordering,
+ C: Ordering,
+ D: Ordering,
+ E: Ordering
+ ]: Ordering[X5[A, B, C, D, E]] =
Ordering.Tuple5[A, B, C, D, E].on(x => (x.a, x.b, x.c, x.d, x.e))
}
case class X6[A, B, C, D, E, F](a: A, b: B, c: C, d: D, e: E, f: F)
object X6 {
- implicit def arbitrary[A: Arbitrary, B: Arbitrary, C: Arbitrary, D: Arbitrary, E: Arbitrary, F: Arbitrary]: Arbitrary[X6[A, B, C, D, E, F]] =
- Arbitrary(Arbitrary.arbTuple6[A, B, C, D, E, F].arbitrary.map((X6.apply[A, B, C, D, E, F] _).tupled))
- implicit def cogen[A, B, C, D, E, F](implicit A: Cogen[A], B: Cogen[B], C: Cogen[C], D: Cogen[D], E: Cogen[E], F: Cogen[F]): Cogen[X6[A, B, C, D, E, F]] =
- Cogen.tuple6(A, B, C, D, E, F).contramap(x => (x.a, x.b, x.c, x.d, x.e, x.f))
-
- implicit def ordering[A: Ordering, B: Ordering, C: Ordering, D: Ordering, E: Ordering, F: Ordering]: Ordering[X6[A, B, C, D, E, F]] =
+ implicit def arbitrary[
+ A: Arbitrary,
+ B: Arbitrary,
+ C: Arbitrary,
+ D: Arbitrary,
+ E: Arbitrary,
+ F: Arbitrary
+ ]: Arbitrary[X6[A, B, C, D, E, F]] =
+ Arbitrary(
+ Arbitrary
+ .arbTuple6[A, B, C, D, E, F]
+ .arbitrary
+ .map((X6.apply[A, B, C, D, E, F] _).tupled)
+ )
+
+ implicit def cogen[A, B, C, D, E, F](
+ implicit
+ A: Cogen[A],
+ B: Cogen[B],
+ C: Cogen[C],
+ D: Cogen[D],
+ E: Cogen[E],
+ F: Cogen[F]
+ ): Cogen[X6[A, B, C, D, E, F]] =
+ Cogen
+ .tuple6(A, B, C, D, E, F)
+ .contramap(x => (x.a, x.b, x.c, x.d, x.e, x.f))
+
+ implicit def ordering[
+ A: Ordering,
+ B: Ordering,
+ C: Ordering,
+ D: Ordering,
+ E: Ordering,
+ F: Ordering
+ ]: Ordering[X6[A, B, C, D, E, F]] =
Ordering.Tuple6[A, B, C, D, E, F].on(x => (x.a, x.b, x.c, x.d, x.e, x.f))
-}
\ No newline at end of file
+}
diff --git a/dataset/src/test/scala/frameless/forward/CheckpointTests.scala b/dataset/src/test/scala/frameless/forward/CheckpointTests.scala
index 9a1ff8b44..d982f0344 100644
--- a/dataset/src/test/scala/frameless/forward/CheckpointTests.scala
+++ b/dataset/src/test/scala/frameless/forward/CheckpointTests.scala
@@ -1,8 +1,7 @@
package frameless
import org.scalacheck.Prop
-import org.scalacheck.Prop.{forAll, _}
-
+import org.scalacheck.Prop.{ forAll, _ }
class CheckpointTests extends TypedDatasetSuite {
test("checkpoint") {
@@ -18,4 +17,4 @@ class CheckpointTests extends TypedDatasetSuite {
check(forAll(prop[Int] _))
check(forAll(prop[String] _))
}
-}
\ No newline at end of file
+}
diff --git a/dataset/src/test/scala/frameless/forward/ColumnsTests.scala b/dataset/src/test/scala/frameless/forward/ColumnsTests.scala
index 282a72c9a..27e20511d 100644
--- a/dataset/src/test/scala/frameless/forward/ColumnsTests.scala
+++ b/dataset/src/test/scala/frameless/forward/ColumnsTests.scala
@@ -5,7 +5,14 @@ import org.scalacheck.Prop.forAll
class ColumnsTests extends TypedDatasetSuite {
test("columns") {
- def prop(i: Int, s: String, b: Boolean, l: Long, d: Double, by: Byte): Prop = {
+ def prop(
+ i: Int,
+ s: String,
+ b: Boolean,
+ l: Long,
+ d: Double,
+ by: Byte
+ ): Prop = {
val x1 = X1(i) :: Nil
val x2 = X2(i, s) :: Nil
val x3 = X3(i, s, b) :: Nil
@@ -13,18 +20,21 @@ class ColumnsTests extends TypedDatasetSuite {
val x5 = X5(i, s, b, l, d) :: Nil
val x6 = X6(i, s, b, l, d, by) :: Nil
- val datasets = Seq(TypedDataset.create(x1), TypedDataset.create(x2),
- TypedDataset.create(x3), TypedDataset.create(x4),
- TypedDataset.create(x5), TypedDataset.create(x6))
+ val datasets = Seq(
+ TypedDataset.create(x1),
+ TypedDataset.create(x2),
+ TypedDataset.create(x3),
+ TypedDataset.create(x4),
+ TypedDataset.create(x5),
+ TypedDataset.create(x6)
+ )
Prop.all(datasets.flatMap { dataset =>
val columns = dataset.dataset.columns
- dataset.columns.map(col =>
- Prop.propBoolean(columns contains col)
- )
+ dataset.columns.map(col => Prop.propBoolean(columns contains col))
}: _*)
}
check(forAll(prop _))
}
-}
\ No newline at end of file
+}
diff --git a/dataset/src/test/scala/frameless/forward/DistinctTests.scala b/dataset/src/test/scala/frameless/forward/DistinctTests.scala
index 44da5e59e..bd2b4be41 100644
--- a/dataset/src/test/scala/frameless/forward/DistinctTests.scala
+++ b/dataset/src/test/scala/frameless/forward/DistinctTests.scala
@@ -7,8 +7,14 @@ import math.Ordering
class DistinctTests extends TypedDatasetSuite {
test("distinct") {
// Comparison done with `.sorted` because order is not preserved by Spark for this operation.
- def prop[A: TypedEncoder : Ordering](data: Vector[A]): Prop =
- TypedDataset.create(data).distinct.collect().run().toVector.sorted ?= data.distinct.sorted
+ def prop[A: TypedEncoder: Ordering](data: Vector[A]): Prop =
+ TypedDataset
+ .create(data)
+ .distinct
+ .collect()
+ .run()
+ .toVector
+ .sorted ?= data.distinct.sorted
check(forAll(prop[Int] _))
check(forAll(prop[String] _))
diff --git a/dataset/src/test/scala/frameless/forward/HeadTests.scala b/dataset/src/test/scala/frameless/forward/HeadTests.scala
index 63f76e003..c3314fc40 100644
--- a/dataset/src/test/scala/frameless/forward/HeadTests.scala
+++ b/dataset/src/test/scala/frameless/forward/HeadTests.scala
@@ -1,6 +1,12 @@
package frameless.forward
-import frameless.{TypedDataset, TypedDatasetSuite, TypedEncoder, TypedExpressionEncoder, X1}
+import frameless.{
+ TypedDataset,
+ TypedDatasetSuite,
+ TypedEncoder,
+ TypedExpressionEncoder,
+ X1
+}
import org.apache.spark.sql.SparkSession
import org.scalacheck.Prop
import org.scalacheck.Prop._
@@ -9,17 +15,25 @@ import scala.reflect.ClassTag
import org.scalatest.matchers.should.Matchers
class HeadTests extends TypedDatasetSuite with Matchers {
- def propArray[A: TypedEncoder : ClassTag : Ordering](data: Vector[X1[A]])(implicit c: SparkSession): Prop = {
+
+ def propArray[A: TypedEncoder: ClassTag: Ordering](
+ data: Vector[X1[A]]
+ )(implicit
+ c: SparkSession
+ ): Prop = {
import c.implicits._
- if(data.nonEmpty) {
- val tds = TypedDataset.
- create(c.createDataset(data)(
+ if (data.nonEmpty) {
+ val tds = TypedDataset.create(
+ c.createDataset(data)(
TypedExpressionEncoder.apply[X1[A]]
- ).orderBy($"a".desc))
- (tds.headOption().run().get ?= data.max).
- &&(tds.head(1).run().head ?= data.max).
- &&(tds.head(4).run().toVector ?=
- data.sortBy(_.a)(implicitly[Ordering[A]].reverse).take(4))
+ ).orderBy($"a".desc)
+ )
+ (tds.headOption().run().get ?= data.max)
+ .&&(tds.head(1).run().head ?= data.max)
+ .&&(
+ tds.head(4).run().toVector ?=
+ data.sortBy(_.a)(implicitly[Ordering[A]].reverse).take(4)
+ )
} else Prop.passed
}
diff --git a/dataset/src/test/scala/frameless/forward/InputFilesTests.scala b/dataset/src/test/scala/frameless/forward/InputFilesTests.scala
index 246867e63..461457879 100644
--- a/dataset/src/test/scala/frameless/forward/InputFilesTests.scala
+++ b/dataset/src/test/scala/frameless/forward/InputFilesTests.scala
@@ -14,7 +14,9 @@ class InputFilesTests extends TypedDatasetSuite with Matchers {
val filePath = s"$TEST_OUTPUT_DIR/${UUID.randomUUID()}.txt"
TypedDataset.create(data).dataset.write.text(filePath)
- val dataset = TypedDataset.create(implicitly[SparkSession].sparkContext.textFile(filePath))
+ val dataset = TypedDataset.create(
+ implicitly[SparkSession].sparkContext.textFile(filePath)
+ )
dataset.inputFiles sameElements dataset.dataset.inputFiles
}
@@ -25,7 +27,10 @@ class InputFilesTests extends TypedDatasetSuite with Matchers {
inputDataset.dataset.write.csv(filePath)
val dataset = TypedDataset.createUnsafe(
- implicitly[SparkSession].sqlContext.read.schema(inputDataset.schema).csv(filePath))
+ implicitly[SparkSession].sqlContext.read
+ .schema(inputDataset.schema)
+ .csv(filePath)
+ )
dataset.inputFiles sameElements dataset.dataset.inputFiles
}
@@ -36,7 +41,10 @@ class InputFilesTests extends TypedDatasetSuite with Matchers {
inputDataset.dataset.write.json(filePath)
val dataset = TypedDataset.createUnsafe(
- implicitly[SparkSession].sqlContext.read.schema(inputDataset.schema).json(filePath))
+ implicitly[SparkSession].sqlContext.read
+ .schema(inputDataset.schema)
+ .json(filePath)
+ )
dataset.inputFiles sameElements dataset.dataset.inputFiles
}
@@ -45,4 +53,4 @@ class InputFilesTests extends TypedDatasetSuite with Matchers {
check(forAll(propCsv[String] _))
check(forAll(propJson[String] _))
}
-}
\ No newline at end of file
+}
diff --git a/dataset/src/test/scala/frameless/forward/IntersectTests.scala b/dataset/src/test/scala/frameless/forward/IntersectTests.scala
index f0edb856e..396356857 100644
--- a/dataset/src/test/scala/frameless/forward/IntersectTests.scala
+++ b/dataset/src/test/scala/frameless/forward/IntersectTests.scala
@@ -6,10 +6,14 @@ import math.Ordering
class IntersectTests extends TypedDatasetSuite {
test("intersect") {
- def prop[A: TypedEncoder : Ordering](data1: Vector[A], data2: Vector[A]): Prop = {
+ def prop[A: TypedEncoder: Ordering](
+ data1: Vector[A],
+ data2: Vector[A]
+ ): Prop = {
val dataset1 = TypedDataset.create(data1)
val dataset2 = TypedDataset.create(data2)
- val datasetIntersect = dataset1.intersect(dataset2).collect().run().toVector
+ val datasetIntersect =
+ dataset1.intersect(dataset2).collect().run().toVector
// Vector `intersect` is the multiset intersection, while Spark throws away duplicates.
val dataIntersect = data1.intersect(data2).distinct
diff --git a/dataset/src/test/scala/frameless/forward/IsLocalTests.scala b/dataset/src/test/scala/frameless/forward/IsLocalTests.scala
index f61d25cd1..71fbd27ce 100644
--- a/dataset/src/test/scala/frameless/forward/IsLocalTests.scala
+++ b/dataset/src/test/scala/frameless/forward/IsLocalTests.scala
@@ -14,4 +14,4 @@ class IsLocalTests extends TypedDatasetSuite {
check(forAll(prop[Int] _))
check(forAll(prop[String] _))
}
-}
\ No newline at end of file
+}
diff --git a/dataset/src/test/scala/frameless/forward/IsStreamingTests.scala b/dataset/src/test/scala/frameless/forward/IsStreamingTests.scala
index dd1874977..b056bc409 100644
--- a/dataset/src/test/scala/frameless/forward/IsStreamingTests.scala
+++ b/dataset/src/test/scala/frameless/forward/IsStreamingTests.scala
@@ -14,4 +14,4 @@ class IsStreamingTests extends TypedDatasetSuite {
check(forAll(prop[Int] _))
check(forAll(prop[String] _))
}
-}
\ No newline at end of file
+}
diff --git a/dataset/src/test/scala/frameless/forward/QueryExecutionTests.scala b/dataset/src/test/scala/frameless/forward/QueryExecutionTests.scala
index d59e250df..0a0d82465 100644
--- a/dataset/src/test/scala/frameless/forward/QueryExecutionTests.scala
+++ b/dataset/src/test/scala/frameless/forward/QueryExecutionTests.scala
@@ -1,7 +1,7 @@
package frameless
import org.scalacheck.Prop
-import org.scalacheck.Prop.{forAll, _}
+import org.scalacheck.Prop.{ forAll, _ }
class QueryExecutionTests extends TypedDatasetSuite {
test("queryExecution") {
@@ -14,4 +14,4 @@ class QueryExecutionTests extends TypedDatasetSuite {
check(forAll(prop[Int] _))
check(forAll(prop[String] _))
}
-}
\ No newline at end of file
+}
diff --git a/dataset/src/test/scala/frameless/forward/RandomSplitTests.scala b/dataset/src/test/scala/frameless/forward/RandomSplitTests.scala
index 4cc9a4fde..63ab904f0 100644
--- a/dataset/src/test/scala/frameless/forward/RandomSplitTests.scala
+++ b/dataset/src/test/scala/frameless/forward/RandomSplitTests.scala
@@ -2,36 +2,45 @@ package frameless
import org.scalacheck.Arbitrary.arbitrary
import org.scalacheck.Prop._
-import org.scalacheck.{Arbitrary, Gen}
+import org.scalacheck.{ Arbitrary, Gen }
import scala.collection.JavaConverters._
import org.scalatest.matchers.should.Matchers
class RandomSplitTests extends TypedDatasetSuite with Matchers {
- val nonEmptyPositiveArray: Gen[Array[Double]] = Gen.nonEmptyListOf(Gen.posNum[Double]).map(_.toArray)
+ val nonEmptyPositiveArray: Gen[Array[Double]] =
+ Gen.nonEmptyListOf(Gen.posNum[Double]).map(_.toArray)
test("randomSplit(weight, seed)") {
- def prop[A: TypedEncoder : Arbitrary] = forAll(vectorGen[A], nonEmptyPositiveArray, arbitrary[Long]) {
- (data: Vector[A], weights: Array[Double], seed: Long) =>
- val dataset = TypedDataset.create(data)
+ def prop[A: TypedEncoder: Arbitrary] =
+ forAll(vectorGen[A], nonEmptyPositiveArray, arbitrary[Long]) {
+ (data: Vector[A], weights: Array[Double], seed: Long) =>
+ val dataset = TypedDataset.create(data)
- dataset.randomSplit(weights, seed).map(_.count().run()) sameElements
- dataset.dataset.randomSplit(weights, seed).map(_.count())
- }
+ dataset.randomSplit(weights, seed).map(_.count().run()) sameElements
+ dataset.dataset.randomSplit(weights, seed).map(_.count())
+ }
check(prop[Int])
check(prop[String])
}
test("randomSplitAsList(weight, seed)") {
- def prop[A: TypedEncoder : Arbitrary] = forAll(vectorGen[A], nonEmptyPositiveArray, arbitrary[Long]) {
- (data: Vector[A], weights: Array[Double], seed: Long) =>
- val dataset = TypedDataset.create(data)
-
- dataset.randomSplitAsList(weights, seed).asScala.map(_.count().run()) sameElements
- dataset.dataset.randomSplitAsList(weights, seed).asScala.map(_.count())
- }
+ def prop[A: TypedEncoder: Arbitrary] =
+ forAll(vectorGen[A], nonEmptyPositiveArray, arbitrary[Long]) {
+ (data: Vector[A], weights: Array[Double], seed: Long) =>
+ val dataset = TypedDataset.create(data)
+
+ dataset
+ .randomSplitAsList(weights, seed)
+ .asScala
+ .map(_.count().run()) sameElements
+ dataset.dataset
+ .randomSplitAsList(weights, seed)
+ .asScala
+ .map(_.count())
+ }
check(prop[Int])
check(prop[String])
diff --git a/dataset/src/test/scala/frameless/forward/SQLContextTests.scala b/dataset/src/test/scala/frameless/forward/SQLContextTests.scala
index 700f29b05..7b524b076 100644
--- a/dataset/src/test/scala/frameless/forward/SQLContextTests.scala
+++ b/dataset/src/test/scala/frameless/forward/SQLContextTests.scala
@@ -1,7 +1,7 @@
package frameless
import org.scalacheck.Prop
-import org.scalacheck.Prop.{forAll, _}
+import org.scalacheck.Prop.{ forAll, _ }
class SQLContextTests extends TypedDatasetSuite {
test("sqlContext") {
diff --git a/dataset/src/test/scala/frameless/forward/SparkSessionTests.scala b/dataset/src/test/scala/frameless/forward/SparkSessionTests.scala
index c5d0da338..ce3130d3b 100644
--- a/dataset/src/test/scala/frameless/forward/SparkSessionTests.scala
+++ b/dataset/src/test/scala/frameless/forward/SparkSessionTests.scala
@@ -14,4 +14,4 @@ class SparkSessionTests extends TypedDatasetSuite {
check(forAll(prop[Int] _))
check(forAll(prop[String] _))
}
-}
\ No newline at end of file
+}
diff --git a/dataset/src/test/scala/frameless/forward/StorageLevelTests.scala b/dataset/src/test/scala/frameless/forward/StorageLevelTests.scala
index 3ac93773e..6b9c0dcd0 100644
--- a/dataset/src/test/scala/frameless/forward/StorageLevelTests.scala
+++ b/dataset/src/test/scala/frameless/forward/StorageLevelTests.scala
@@ -3,27 +3,41 @@ package frameless
import org.apache.spark.storage.StorageLevel
import org.apache.spark.storage.StorageLevel._
import org.scalacheck.Prop._
-import org.scalacheck.{Arbitrary, Gen}
+import org.scalacheck.{ Arbitrary, Gen }
class StorageLevelTests extends TypedDatasetSuite {
- val storageLevelGen: Gen[StorageLevel] = Gen.oneOf(Seq(NONE, DISK_ONLY, DISK_ONLY_2, MEMORY_ONLY,
- MEMORY_ONLY_2, MEMORY_ONLY_SER, MEMORY_ONLY_SER_2, MEMORY_AND_DISK,
- MEMORY_AND_DISK_2, MEMORY_AND_DISK_SER, MEMORY_AND_DISK_SER_2, OFF_HEAP))
+ val storageLevelGen: Gen[StorageLevel] = Gen.oneOf(
+ Seq(
+ NONE,
+ DISK_ONLY,
+ DISK_ONLY_2,
+ MEMORY_ONLY,
+ MEMORY_ONLY_2,
+ MEMORY_ONLY_SER,
+ MEMORY_ONLY_SER_2,
+ MEMORY_AND_DISK,
+ MEMORY_AND_DISK_2,
+ MEMORY_AND_DISK_SER,
+ MEMORY_AND_DISK_SER_2,
+ OFF_HEAP
+ )
+ )
test("storageLevel") {
- def prop[A: TypedEncoder : Arbitrary] = forAll(vectorGen[A], storageLevelGen) {
- (data: Vector[A], storageLevel: StorageLevel) =>
- val dataset = TypedDataset.create(data)
- if (storageLevel != StorageLevel.NONE)
- dataset.persist(storageLevel)
+ def prop[A: TypedEncoder: Arbitrary] =
+ forAll(vectorGen[A], storageLevelGen) {
+ (data: Vector[A], storageLevel: StorageLevel) =>
+ val dataset = TypedDataset.create(data)
+ if (storageLevel != StorageLevel.NONE)
+ dataset.persist(storageLevel)
- dataset.count().run()
+ dataset.count().run()
- dataset.storageLevel() ?= dataset.dataset.storageLevel
- }
+ dataset.storageLevel() ?= dataset.dataset.storageLevel
+ }
check(prop[Int])
check(prop[String])
}
-}
\ No newline at end of file
+}
diff --git a/dataset/src/test/scala/frameless/forward/TakeTests.scala b/dataset/src/test/scala/frameless/forward/TakeTests.scala
index eec77bc80..a7f44a678 100644
--- a/dataset/src/test/scala/frameless/forward/TakeTests.scala
+++ b/dataset/src/test/scala/frameless/forward/TakeTests.scala
@@ -7,14 +7,22 @@ import scala.reflect.ClassTag
class TakeTests extends TypedDatasetSuite {
test("take") {
def prop[A: TypedEncoder](n: Int, data: Vector[A]): Prop =
- (n >= 0) ==> (TypedDataset.create(data).take(n).run().toVector =? data.take(n))
+ (n >= 0) ==> (TypedDataset.create(data).take(n).run().toVector =? data
+ .take(n))
- def propArray[A: TypedEncoder: ClassTag](n: Int, data: Vector[X1[Array[A]]]): Prop =
+ def propArray[A: TypedEncoder: ClassTag](
+ n: Int,
+ data: Vector[X1[Array[A]]]
+ ): Prop =
(n >= 0) ==> {
Prop {
- TypedDataset.create(data).take(n).run().toVector.zip(data.take(n)).forall {
- case (X1(l), X1(r)) => l sameElements r
- }
+ TypedDataset
+ .create(data)
+ .take(n)
+ .run()
+ .toVector
+ .zip(data.take(n))
+ .forall { case (X1(l), X1(r)) => l sameElements r }
}
}
diff --git a/dataset/src/test/scala/frameless/forward/ToJSONTests.scala b/dataset/src/test/scala/frameless/forward/ToJSONTests.scala
index 5ed79a9c9..5e78ea6d0 100644
--- a/dataset/src/test/scala/frameless/forward/ToJSONTests.scala
+++ b/dataset/src/test/scala/frameless/forward/ToJSONTests.scala
@@ -14,4 +14,4 @@ class ToJSONTests extends TypedDatasetSuite {
check(forAll(prop[Int] _))
check(forAll(prop[String] _))
}
-}
\ No newline at end of file
+}
diff --git a/dataset/src/test/scala/frameless/forward/ToLocalIteratorTests.scala b/dataset/src/test/scala/frameless/forward/ToLocalIteratorTests.scala
index faaf25caf..436215008 100644
--- a/dataset/src/test/scala/frameless/forward/ToLocalIteratorTests.scala
+++ b/dataset/src/test/scala/frameless/forward/ToLocalIteratorTests.scala
@@ -10,7 +10,14 @@ class ToLocalIteratorTests extends TypedDatasetSuite with Matchers {
def prop[A: TypedEncoder](data: Vector[A]): Prop = {
val dataset = TypedDataset.create(data)
- dataset.toLocalIterator().run().asScala.toIterator sameElements dataset.dataset.toLocalIterator().asScala.toIterator
+ dataset
+ .toLocalIterator()
+ .run()
+ .asScala
+ .toIterator sameElements dataset.dataset
+ .toLocalIterator()
+ .asScala
+ .toIterator
}
check(forAll(prop[Int] _))
diff --git a/dataset/src/test/scala/frameless/forward/UnionTests.scala b/dataset/src/test/scala/frameless/forward/UnionTests.scala
index 6cd8f4005..b927f24f7 100644
--- a/dataset/src/test/scala/frameless/forward/UnionTests.scala
+++ b/dataset/src/test/scala/frameless/forward/UnionTests.scala
@@ -30,11 +30,21 @@ class UnionTests extends TypedDatasetSuite {
}
test("Align fields for case classes") {
- def prop[A: TypedEncoder, B: TypedEncoder](data1: Vector[(A, B)], data2: Vector[(A, B)]): Prop = {
+ def prop[A: TypedEncoder, B: TypedEncoder](
+ data1: Vector[(A, B)],
+ data2: Vector[(A, B)]
+ ): Prop = {
val dataset1 = TypedDataset.create(data1.map((Foo.apply[A, B] _).tupled))
- val dataset2 = TypedDataset.create(data2.map { case (a, b) => Bar[A, B](b, a) })
- val datasetUnion = dataset1.union(dataset2).collect().run().map(foo => (foo.x, foo.y)).toVector
+ val dataset2 = TypedDataset.create(data2.map {
+ case (a, b) => Bar[A, B](b, a)
+ })
+ val datasetUnion = dataset1
+ .union(dataset2)
+ .collect()
+ .run()
+ .map(foo => (foo.x, foo.y))
+ .toVector
val dataUnion = data1 union data2
datasetUnion ?= dataUnion
@@ -45,11 +55,21 @@ class UnionTests extends TypedDatasetSuite {
}
test("Align fields for different number of columns") {
- def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder](data1: Vector[(A, B, C)], data2: Vector[(A, B)]): Prop = {
+ def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder](
+ data1: Vector[(A, B, C)],
+ data2: Vector[(A, B)]
+ ): Prop = {
val dataset1 = TypedDataset.create(data2.map((Foo.apply[A, B] _).tupled))
- val dataset2 = TypedDataset.create(data1.map { case (a, b, c) => Baz[A, B, C](c, b, a) })
- val datasetUnion: Seq[(A, B)] = dataset1.union(dataset2).collect().run().map(foo => (foo.x, foo.y)).toVector
+ val dataset2 = TypedDataset.create(data1.map {
+ case (a, b, c) => Baz[A, B, C](c, b, a)
+ })
+ val datasetUnion: Seq[(A, B)] = dataset1
+ .union(dataset2)
+ .collect()
+ .run()
+ .map(foo => (foo.x, foo.y))
+ .toVector
val dataUnion = data2 union data1.map { case (a, b, _) => (a, b) }
datasetUnion ?= dataUnion
@@ -63,4 +83,4 @@ class UnionTests extends TypedDatasetSuite {
final case class Foo[A, B](x: A, y: B)
final case class Bar[A, B](y: B, x: A)
final case class Baz[A, B, C](z: C, y: B, x: A)
-final case class Wrong[A, B, C](a: A, b: B, c: C)
\ No newline at end of file
+final case class Wrong[A, B, C](a: A, b: B, c: C)
diff --git a/dataset/src/test/scala/frameless/forward/WriteStreamTests.scala b/dataset/src/test/scala/frameless/forward/WriteStreamTests.scala
index 368147c93..462e14e31 100644
--- a/dataset/src/test/scala/frameless/forward/WriteStreamTests.scala
+++ b/dataset/src/test/scala/frameless/forward/WriteStreamTests.scala
@@ -5,7 +5,7 @@ import java.util.UUID
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.scalacheck.Prop._
-import org.scalacheck.{Arbitrary, Gen, Prop}
+import org.scalacheck.{ Arbitrary, Gen, Prop }
class WriteStreamTests extends TypedDatasetSuite {
@@ -36,23 +36,34 @@ class WriteStreamTests extends TypedDatasetSuite {
val checkpointPath = s"$TEST_OUTPUT_DIR/checkpoint/$uid"
val inputStream = MemoryStream[A]
val input = TypedDataset.create(inputStream.toDS())
- val inputter = input.writeStream.format("csv").option("checkpointLocation", s"$checkpointPath/input").start(filePath)
+ val inputter = input.writeStream
+ .format("csv")
+ .option("checkpointLocation", s"$checkpointPath/input")
+ .start(filePath)
inputStream.addData(data)
inputter.processAllAvailable()
- val dataset = TypedDataset.createUnsafe(sqlContext.readStream.schema(input.schema).csv(filePath))
+ val dataset = TypedDataset.createUnsafe(
+ sqlContext.readStream.schema(input.schema).csv(filePath)
+ )
- val tester = dataset
- .writeStream
+ val tester = dataset.writeStream
.option("checkpointLocation", s"$checkpointPath/tester")
.format("memory")
.queryName(s"testCsv_$uidNoHyphens")
.start()
tester.processAllAvailable()
val output = spark.table(s"testCsv_$uidNoHyphens").as[A]
- TypedDataset.create(data).collect().run().groupBy(identity) ?= output.collect().groupBy(identity).map { case (k, arr) => (k, arr.toSeq) }
+ TypedDataset.create(data).collect().run().groupBy(identity) ?= output
+ .collect()
+ .groupBy(identity)
+ .map { case (k, arr) => (k, arr.toSeq) }
}
- check(forAll(Gen.nonEmptyListOf(Gen.alphaNumStr.suchThat(_.nonEmpty)))(prop[String]))
+ check(
+ forAll(Gen.nonEmptyListOf(Gen.alphaNumStr.suchThat(_.nonEmpty)))(
+ prop[String]
+ )
+ )
check(forAll(Gen.nonEmptyListOf(Arbitrary.arbitrary[Int]))(prop[Int]))
}
@@ -66,20 +77,27 @@ class WriteStreamTests extends TypedDatasetSuite {
val checkpointPath = s"$TEST_OUTPUT_DIR/checkpoint/$uid"
val inputStream = MemoryStream[A]
val input = TypedDataset.create(inputStream.toDS())
- val inputter = input.writeStream.format("parquet").option("checkpointLocation", s"$checkpointPath/input").start(filePath)
+ val inputter = input.writeStream
+ .format("parquet")
+ .option("checkpointLocation", s"$checkpointPath/input")
+ .start(filePath)
inputStream.addData(data)
inputter.processAllAvailable()
- val dataset = TypedDataset.createUnsafe(sqlContext.readStream.schema(input.schema).parquet(filePath))
+ val dataset = TypedDataset.createUnsafe(
+ sqlContext.readStream.schema(input.schema).parquet(filePath)
+ )
- val tester = dataset
- .writeStream
+ val tester = dataset.writeStream
.option("checkpointLocation", s"$checkpointPath/tester")
.format("memory")
.queryName(s"testParquet_$uidNoHyphens")
.start()
tester.processAllAvailable()
val output = spark.table(s"testParquet_$uidNoHyphens").as[A]
- TypedDataset.create(data).collect().run().groupBy(identity) ?= output.collect().groupBy(identity).map { case (k, arr) => (k, arr.toSeq) }
+ TypedDataset.create(data).collect().run().groupBy(identity) ?= output
+ .collect()
+ .groupBy(identity)
+ .map { case (k, arr) => (k, arr.toSeq) }
}
check(forAll(Gen.nonEmptyListOf(genWriteExample))(prop[WriteExample]))
diff --git a/dataset/src/test/scala/frameless/forward/WriteTests.scala b/dataset/src/test/scala/frameless/forward/WriteTests.scala
index d5a9057cb..4504935c3 100644
--- a/dataset/src/test/scala/frameless/forward/WriteTests.scala
+++ b/dataset/src/test/scala/frameless/forward/WriteTests.scala
@@ -3,7 +3,7 @@ package frameless
import java.util.UUID
import org.scalacheck.Prop._
-import org.scalacheck.{Arbitrary, Gen, Prop}
+import org.scalacheck.{ Arbitrary, Gen, Prop }
class WriteTests extends TypedDatasetSuite {
@@ -30,12 +30,19 @@ class WriteTests extends TypedDatasetSuite {
val input = TypedDataset.create(data)
input.write.csv(filePath)
- val dataset = TypedDataset.createUnsafe(sqlContext.read.schema(input.schema).csv(filePath))
+ val dataset = TypedDataset.createUnsafe(
+ sqlContext.read.schema(input.schema).csv(filePath)
+ )
- dataset.collect().run().groupBy(identity) ?= input.collect().run().groupBy(identity)
+ dataset.collect().run().groupBy(identity) ?= input
+ .collect()
+ .run()
+ .groupBy(identity)
}
- check(forAll(Gen.listOf(Gen.alphaNumStr.suchThat(_.nonEmpty)))(prop[String]))
+ check(
+ forAll(Gen.listOf(Gen.alphaNumStr.suchThat(_.nonEmpty)))(prop[String])
+ )
check(forAll(prop[Int] _))
}
@@ -45,9 +52,14 @@ class WriteTests extends TypedDatasetSuite {
val input = TypedDataset.create(data)
input.write.parquet(filePath)
- val dataset = TypedDataset.createUnsafe(sqlContext.read.schema(input.schema).parquet(filePath))
+ val dataset = TypedDataset.createUnsafe(
+ sqlContext.read.schema(input.schema).parquet(filePath)
+ )
- dataset.collect().run().groupBy(identity) ?= input.collect().run().groupBy(identity)
+ dataset.collect().run().groupBy(identity) ?= input
+ .collect()
+ .run()
+ .groupBy(identity)
}
check(forAll(Gen.listOf(genWriteExample))(prop[WriteExample]))
@@ -56,4 +68,9 @@ class WriteTests extends TypedDatasetSuite {
case class Nested(i: Double, v: String)
case class OptionFieldsOnly(o1: Option[Int], o2: Option[Nested])
-case class WriteExample(i: Int, s: String, on: Option[Nested], ooo: Option[OptionFieldsOnly])
+
+case class WriteExample(
+ i: Int,
+ s: String,
+ on: Option[Nested],
+ ooo: Option[OptionFieldsOnly])
diff --git a/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala b/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala
index 201d93c63..a491b3816 100644
--- a/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala
+++ b/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala
@@ -1,31 +1,37 @@
package frameless
package functions
-import frameless.{TypedAggregate, TypedColumn}
+import frameless.{ TypedAggregate, TypedColumn }
import frameless.functions.aggregate._
-import org.apache.spark.sql.{Column, Encoder}
-import org.scalacheck.{Gen, Prop}
+import org.apache.spark.sql.{ Column, Encoder }
+import org.scalacheck.{ Gen, Prop }
import org.scalacheck.Prop._
import org.scalatest.exceptions.GeneratorDrivenPropertyCheckFailedException
class AggregateFunctionsTests extends TypedDatasetSuite {
- def sparkSchema[A: TypedEncoder, U](f: TypedColumn[X1[A], A] => TypedAggregate[X1[A], U]): Prop = {
+
+ def sparkSchema[A: TypedEncoder, U](
+ f: TypedColumn[X1[A], A] => TypedAggregate[X1[A], U]
+ ): Prop = {
val df = TypedDataset.create[X1[A]](Nil)
val col = f(df.col('a))
val sumDf = df.agg(col)
- TypedExpressionEncoder.targetStructType(sumDf.encoder) ?= sumDf.dataset.schema
+ TypedExpressionEncoder.targetStructType(
+ sumDf.encoder
+ ) ?= sumDf.dataset.schema
}
test("sum") {
case class Sum4Tests[A, B](sum: Seq[A] => B)
- def prop[A: TypedEncoder, Out: TypedEncoder : Numeric](xs: List[A])(
- implicit
- summable: CatalystSummable[A, Out],
- summer: Sum4Tests[A, Out]
- ): Prop = {
+ def prop[A: TypedEncoder, Out: TypedEncoder: Numeric](
+ xs: List[A]
+ )(implicit
+ summable: CatalystSummable[A, Out],
+ summer: Sum4Tests[A, Out]
+ ): Prop = {
val dataset = TypedDataset.create(xs.map(X1(_)))
val A = dataset.col[A]('a)
@@ -33,7 +39,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
datasetSum match {
case x :: Nil => approximatelyEqual(summer.sum(xs), x)
- case other => falsified
+ case other => falsified
}
}
@@ -61,27 +67,31 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
test("sumDistinct") {
case class Sum4Tests[A, B](sum: Seq[A] => B)
- def prop[A: TypedEncoder, Out: TypedEncoder : Numeric](xs: List[A])(
- implicit
- summable: CatalystSummable[A, Out],
- summer: Sum4Tests[A, Out]
- ): Prop = {
+ def prop[A: TypedEncoder, Out: TypedEncoder: Numeric](
+ xs: List[A]
+ )(implicit
+ summable: CatalystSummable[A, Out],
+ summer: Sum4Tests[A, Out]
+ ): Prop = {
val dataset = TypedDataset.create(xs.map(X1(_)))
val A = dataset.col[A]('a)
- val datasetSum: List[Out] = dataset.agg(sumDistinct(A)).collect().run().toList
+ val datasetSum: List[Out] =
+ dataset.agg(sumDistinct(A)).collect().run().toList
datasetSum match {
case x :: Nil => approximatelyEqual(summer.sum(xs), x)
- case other => falsified
+ case other => falsified
}
}
    // Replicate Spark's behaviour: Ints and Shorts are cast to Long
// https://github.com/apache/spark/blob/7eb2ca8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L37
implicit def summerLong = Sum4Tests[Long, Long](_.toSet.sum)
- implicit def summerInt = Sum4Tests[Int, Long]( x => x.toSet.map((_:Int).toLong).sum)
- implicit def summerShort = Sum4Tests[Short, Long](x => x.toSet.map((_:Short).toLong).sum)
+ implicit def summerInt =
+ Sum4Tests[Int, Long](x => x.toSet.map((_: Int).toLong).sum)
+ implicit def summerShort =
+ Sum4Tests[Short, Long](x => x.toSet.map((_: Short).toLong).sum)
check(forAll(prop[Long, Long] _))
check(forAll(prop[Int, Long] _))
@@ -95,33 +105,41 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
test("avg") {
case class Averager4Tests[A, B](avg: Seq[A] => B)
- def prop[A: TypedEncoder, Out: TypedEncoder : Numeric](xs: List[A])(
- implicit
- averageable: CatalystAverageable[A, Out],
- averager: Averager4Tests[A, Out]
- ): Prop = {
+ def prop[A: TypedEncoder, Out: TypedEncoder: Numeric](
+ xs: List[A]
+ )(implicit
+ averageable: CatalystAverageable[A, Out],
+ averager: Averager4Tests[A, Out]
+ ): Prop = {
val dataset = TypedDataset.create(xs.map(X1(_)))
val A = dataset.col[A]('a)
val datasetAvg: Vector[Out] = dataset.agg(avg(A)).collect().run().toVector
if (datasetAvg.size > 2) falsified
- else xs match {
- case Nil => datasetAvg ?= Vector()
- case _ :: _ => datasetAvg.headOption match {
- case Some(x) => approximatelyEqual(averager.avg(xs), x)
- case None => falsified
+ else
+ xs match {
+ case Nil => datasetAvg ?= Vector()
+ case _ :: _ =>
+ datasetAvg.headOption match {
+ case Some(x) => approximatelyEqual(averager.avg(xs), x)
+ case None => falsified
+ }
}
- }
}
    // Replicate Spark's behaviour: if the datatype isn't BigDecimal, cast to Double
// https://github.com/apache/spark/blob/7eb2ca8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L50
- implicit def averageDecimal = Averager4Tests[BigDecimal, BigDecimal](as => as.sum/as.size)
- implicit def averageDouble = Averager4Tests[Double, Double](as => as.sum/as.size)
- implicit def averageLong = Averager4Tests[Long, Double](as => as.map(_.toDouble).sum/as.size)
- implicit def averageInt = Averager4Tests[Int, Double](as => as.map(_.toDouble).sum/as.size)
- implicit def averageShort = Averager4Tests[Short, Double](as => as.map(_.toDouble).sum/as.size)
+ implicit def averageDecimal =
+ Averager4Tests[BigDecimal, BigDecimal](as => as.sum / as.size)
+ implicit def averageDouble =
+ Averager4Tests[Double, Double](as => as.sum / as.size)
+ implicit def averageLong =
+ Averager4Tests[Long, Double](as => as.map(_.toDouble).sum / as.size)
+ implicit def averageInt =
+ Averager4Tests[Int, Double](as => as.map(_.toDouble).sum / as.size)
+ implicit def averageShort =
+ Averager4Tests[Short, Double](as => as.map(_.toDouble).sum / as.size)
/* under 3.4 an oddity was detected:
Falsified after 2 successful property evaluations.
@@ -141,20 +159,27 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
}
test("stddev and variance") {
- def prop[A: TypedEncoder : CatalystVariance : Numeric](xs: List[A]): Prop = {
+ def prop[A: TypedEncoder: CatalystVariance: Numeric](xs: List[A]): Prop = {
val numeric = implicitly[Numeric[A]]
val dataset = TypedDataset.create(xs.map(X1(_)))
val A = dataset.col[A]('a)
- val datasetStdOpt = dataset.agg(stddev(A)).collect().run().toVector.headOption
- val datasetVarOpt = dataset.agg(variance(A)).collect().run().toVector.headOption
+ val datasetStdOpt =
+ dataset.agg(stddev(A)).collect().run().toVector.headOption
+ val datasetVarOpt =
+ dataset.agg(variance(A)).collect().run().toVector.headOption
- val std = sc.parallelize(xs.map(implicitly[Numeric[A]].toDouble)).sampleStdev()
- val `var` = sc.parallelize(xs.map(implicitly[Numeric[A]].toDouble)).sampleVariance()
+ val std =
+ sc.parallelize(xs.map(implicitly[Numeric[A]].toDouble)).sampleStdev()
+ val `var` =
+ sc.parallelize(xs.map(implicitly[Numeric[A]].toDouble)).sampleVariance()
(datasetStdOpt, datasetVarOpt) match {
case (Some(datasetStd), Some(datasetVar)) =>
- approximatelyEqual(datasetStd, std) && approximatelyEqual(datasetVar, `var`)
+ approximatelyEqual(datasetStd, std) && approximatelyEqual(
+ datasetVar,
+ `var`
+ )
case _ => proved
}
}
@@ -167,9 +192,17 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
}
test("litAggr") {
- def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder](xs: List[A], b: B, c: C): Prop = {
+ def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder](
+ xs: List[A],
+ b: B,
+ c: C
+ ): Prop = {
val dataset = TypedDataset.create(xs)
- val (r1, rb, rc, rcount) = dataset.agg(count().lit(1), litAggr(b), litAggr(c), count()).collect().run().head
+ val (r1, rb, rc, rcount) = dataset
+ .agg(count().lit(1), litAggr(b), litAggr(c), count())
+ .collect()
+ .run()
+ .head
(rcount ?= xs.size.toLong) && (r1 ?= 1) && (rb ?= b) && (rc ?= c)
}
@@ -203,7 +236,11 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
}
test("max") {
- def prop[A: TypedEncoder: CatalystOrdered](xs: List[A])(implicit o: Ordering[A]): Prop = {
+ def prop[A: TypedEncoder: CatalystOrdered](
+ xs: List[A]
+ )(implicit
+ o: Ordering[A]
+ ): Prop = {
val dataset = TypedDataset.create(xs.map(X1(_)))
val A = dataset.col[A]('a)
val datasetMax = dataset.agg(max(A)).collect().run().toList
@@ -225,14 +262,18 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
val A = dataset.col[Long]('a)
val datasetMax = dataset.agg(max(A) * 2).collect().run().headOption
- datasetMax ?= (if(xs.isEmpty) None else Some(xs.max * 2))
+ datasetMax ?= (if (xs.isEmpty) None else Some(xs.max * 2))
}
check(forAll(prop _))
}
test("min") {
- def prop[A: TypedEncoder: CatalystOrdered](xs: List[A])(implicit o: Ordering[A]): Prop = {
+ def prop[A: TypedEncoder: CatalystOrdered](
+ xs: List[A]
+ )(implicit
+ o: Ordering[A]
+ ): Prop = {
val dataset = TypedDataset.create(xs.map(X1(_)))
val A = dataset.col[A]('a)
@@ -301,8 +342,13 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
check {
forAll(getLowCardinalityKVPairs) { xs: Vector[(Int, Int)] =>
val tds = TypedDataset.create(xs)
- val tdsRes: Seq[(Int, Long)] = tds.groupBy(tds('_1)).agg(countDistinct(tds('_2))).collect().run()
- tdsRes.toMap ?= xs.groupBy(_._1).mapValues(_.map(_._2).distinct.size.toLong).toSeq.toMap
+ val tdsRes: Seq[(Int, Long)] =
+ tds.groupBy(tds('_1)).agg(countDistinct(tds('_2))).collect().run()
+ tdsRes.toMap ?= xs
+ .groupBy(_._1)
+ .mapValues(_.map(_._2).distinct.size.toLong)
+ .toSeq
+ .toMap
}
}
}
@@ -310,7 +356,11 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
test("approxCountDistinct") {
// Simple version of #approximatelyEqual()
// Default maximum estimation error of HyperLogLog in Spark is 5%
- def approxEqual(actual: Long, estimated: Long, allowedDeviationPercentile: Double = 0.05): Boolean = {
+ def approxEqual(
+ actual: Long,
+ estimated: Long,
+ allowedDeviationPercentile: Double = 0.05
+ ): Boolean = {
val delta: Long = Math.abs(actual - estimated)
delta / actual.toDouble < allowedDeviationPercentile * 2
}
@@ -319,7 +369,11 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
forAll(getLowCardinalityKVPairs) { xs: Vector[(Int, Int)] =>
val tds = TypedDataset.create(xs)
val tdsRes: Seq[(Int, Long, Long)] =
- tds.groupBy(tds('_1)).agg(countDistinct(tds('_2)), approxCountDistinct(tds('_2))).collect().run()
+ tds
+ .groupBy(tds('_1))
+ .agg(countDistinct(tds('_2)), approxCountDistinct(tds('_2)))
+ .collect()
+ .run()
tdsRes.forall { case (_, v1, v2) => approxEqual(v1, v2) }
}
}
@@ -329,18 +383,28 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
val tds = TypedDataset.create(xs)
val allowedError = 0.1 // 10%
val tdsRes: Seq[(Int, Long, Long)] =
- tds.groupBy(tds('_1)).agg(countDistinct(tds('_2)), approxCountDistinct(tds('_2), allowedError)).collect().run()
+ tds
+ .groupBy(tds('_1))
+ .agg(
+ countDistinct(tds('_2)),
+ approxCountDistinct(tds('_2), allowedError)
+ )
+ .collect()
+ .run()
tdsRes.forall { case (_, v1, v2) => approxEqual(v1, v2, allowedError) }
}
}
}
test("collectList") {
- def prop[A: TypedEncoder : Ordering](xs: List[X2[A, A]]): Prop = {
+ def prop[A: TypedEncoder: Ordering](xs: List[X2[A, A]]): Prop = {
val tds = TypedDataset.create(xs)
- val tdsRes: Seq[(A, Vector[A])] = tds.groupBy(tds('a)).agg(collectList(tds('b))).collect().run()
+ val tdsRes: Seq[(A, Vector[A])] =
+ tds.groupBy(tds('a)).agg(collectList(tds('b))).collect().run()
- tdsRes.toMap.map { case (k, v) => k -> v.sorted } ?= xs.groupBy(_.a).map { case (k, v) => k -> v.map(_.b).toVector.sorted }
+ tdsRes.toMap.map { case (k, v) => k -> v.sorted } ?= xs.groupBy(_.a).map {
+ case (k, v) => k -> v.map(_.b).toVector.sorted
+ }
}
check(forAll(prop[Long] _))
@@ -350,11 +414,14 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
}
test("collectSet") {
- def prop[A: TypedEncoder : Ordering](xs: List[X2[A, A]]): Prop = {
+ def prop[A: TypedEncoder: Ordering](xs: List[X2[A, A]]): Prop = {
val tds = TypedDataset.create(xs)
- val tdsRes: Seq[(A, Vector[A])] = tds.groupBy(tds('a)).agg(collectSet(tds('b))).collect().run()
+ val tdsRes: Seq[(A, Vector[A])] =
+ tds.groupBy(tds('a)).agg(collectSet(tds('b))).collect().run()
- tdsRes.toMap.map { case (k, v) => k -> v.toSet } ?= xs.groupBy(_.a).map { case (k, v) => k -> v.map(_.b).toSet }
+ tdsRes.toMap.map { case (k, v) => k -> v.toSet } ?= xs.groupBy(_.a).map {
+ case (k, v) => k -> v.map(_.b).toSet
+ }
}
check(forAll(prop[Long] _))
@@ -379,77 +446,76 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(prop[BigDecimal] _))
}
-
- def bivariatePropTemplate[A: TypedEncoder, B: TypedEncoder]
- (
- xs: List[X3[Int, A, B]]
- )
- (
- framelessFun: (TypedColumn[X3[Int, A, B], A], TypedColumn[X3[Int, A, B], B]) => TypedAggregate[X3[Int, A, B], Option[Double]],
- sparkFun: (Column, Column) => Column
- )
- (
- implicit
- encEv: Encoder[(Int, A, B)],
- encEv2: Encoder[(Int,Option[Double])],
- evCanBeDoubleA: CatalystCast[A, Double],
- evCanBeDoubleB: CatalystCast[B, Double]
- ): Prop = {
+ def bivariatePropTemplate[A: TypedEncoder, B: TypedEncoder](
+ xs: List[X3[Int, A, B]]
+ )(framelessFun: (
+ TypedColumn[X3[Int, A, B], A],
+ TypedColumn[X3[Int, A, B], B]
+ ) => TypedAggregate[X3[Int, A, B], Option[Double]],
+ sparkFun: (Column, Column) => Column
+ )(implicit
+ encEv: Encoder[(Int, A, B)],
+ encEv2: Encoder[(Int, Option[Double])],
+ evCanBeDoubleA: CatalystCast[A, Double],
+ evCanBeDoubleB: CatalystCast[B, Double]
+ ): Prop = {
val tds = TypedDataset.create(xs)
// Typed implementation of bivar stats function
- val tdBivar = tds.groupBy(tds('a)).agg(framelessFun(tds('b), tds('c))).deserialized.map(kv =>
- (kv._1, kv._2.flatMap(DoubleBehaviourUtils.nanNullHandler))
- ).collect().run()
+ val tdBivar = tds
+ .groupBy(tds('a))
+ .agg(framelessFun(tds('b), tds('c)))
+ .deserialized
+ .map(kv => (kv._1, kv._2.flatMap(DoubleBehaviourUtils.nanNullHandler)))
+ .collect()
+ .run()
val cDF = session.createDataset(xs.map(x => (x.a, x.b, x.c)))
// Comparison implementation of bivar stats functions
val compBivar = cDF
.groupBy(cDF("_1"))
.agg(sparkFun(cDF("_2"), cDF("_3")))
- .map(
- row => {
- val grp = row.getInt(0)
- (grp, DoubleBehaviourUtils.nanNullHandler(row.get(1)))
- }
- )
+ .map(row => {
+ val grp = row.getInt(0)
+ (grp, DoubleBehaviourUtils.nanNullHandler(row.get(1)))
+ })
// Should be the same
tdBivar.toMap ?= compBivar.collect().toMap
}
- def univariatePropTemplate[A: TypedEncoder]
- (
- xs: List[X2[Int, A]]
- )
- (
- framelessFun: (TypedColumn[X2[Int, A], A]) => TypedAggregate[X2[Int, A], Option[Double]],
- sparkFun: (Column) => Column
- )
- (
- implicit
- encEv: Encoder[(Int, A)],
- encEv2: Encoder[(Int,Option[Double])],
- evCanBeDoubleA: CatalystCast[A, Double]
- ): Prop = {
+ def univariatePropTemplate[A: TypedEncoder](
+ xs: List[X2[Int, A]]
+ )(framelessFun: (TypedColumn[X2[Int, A], A]) => TypedAggregate[
+ X2[Int, A],
+ Option[Double]
+ ],
+ sparkFun: (Column) => Column
+ )(implicit
+ encEv: Encoder[(Int, A)],
+ encEv2: Encoder[(Int, Option[Double])],
+ evCanBeDoubleA: CatalystCast[A, Double]
+ ): Prop = {
val tds = TypedDataset.create(xs)
- //typed implementation of univariate stats function
- val tdUnivar = tds.groupBy(tds('a)).agg(framelessFun(tds('b))).deserialized.map(kv =>
- (kv._1, kv._2.flatMap(DoubleBehaviourUtils.nanNullHandler))
- ).collect().run()
+ // typed implementation of univariate stats function
+ val tdUnivar = tds
+ .groupBy(tds('a))
+ .agg(framelessFun(tds('b)))
+ .deserialized
+ .map(kv => (kv._1, kv._2.flatMap(DoubleBehaviourUtils.nanNullHandler)))
+ .collect()
+ .run()
val cDF = session.createDataset(xs.map(x => (x.a, x.b)))
    // Comparison implementation of the univariate stats function
val compUnivar = cDF
.groupBy(cDF("_1"))
.agg(sparkFun(cDF("_2")))
- .map(
- row => {
- val grp = row.getInt(0)
- (grp, DoubleBehaviourUtils.nanNullHandler(row.get(1)))
- }
- )
+ .map(row => {
+ val grp = row.getInt(0)
+ (grp, DoubleBehaviourUtils.nanNullHandler(row.get(1)))
+ })
// Should be the same
tdUnivar.toMap ?= compUnivar.collect().toMap
@@ -459,12 +525,16 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: TypedEncoder, B: TypedEncoder](xs: List[X3[Int, A, B]])(
- implicit
- encEv: Encoder[(Int, A, B)],
- evCanBeDoubleA: CatalystCast[A, Double],
- evCanBeDoubleB: CatalystCast[B, Double]
- ): Prop = bivariatePropTemplate(xs)(corr[A,B,X3[Int, A, B]],org.apache.spark.sql.functions.corr)
+ def prop[A: TypedEncoder, B: TypedEncoder](
+ xs: List[X3[Int, A, B]]
+ )(implicit
+ encEv: Encoder[(Int, A, B)],
+ evCanBeDoubleA: CatalystCast[A, Double],
+ evCanBeDoubleB: CatalystCast[B, Double]
+ ): Prop = bivariatePropTemplate(xs)(
+ corr[A, B, X3[Int, A, B]],
+ org.apache.spark.sql.functions.corr
+ )
check(forAll(prop[Double, Double] _))
check(forAll(prop[Double, Int] _))
@@ -477,12 +547,13 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: TypedEncoder, B: TypedEncoder](xs: List[X3[Int, A, B]])(
- implicit
- encEv: Encoder[(Int, A, B)],
- evCanBeDoubleA: CatalystCast[A, Double],
- evCanBeDoubleB: CatalystCast[B, Double]
- ): Prop = bivariatePropTemplate(xs)(
+ def prop[A: TypedEncoder, B: TypedEncoder](
+ xs: List[X3[Int, A, B]]
+ )(implicit
+ encEv: Encoder[(Int, A, B)],
+ evCanBeDoubleA: CatalystCast[A, Double],
+ evCanBeDoubleB: CatalystCast[B, Double]
+ ): Prop = bivariatePropTemplate(xs)(
covarPop[A, B, X3[Int, A, B]],
org.apache.spark.sql.functions.covar_pop
)
@@ -498,12 +569,13 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: TypedEncoder, B: TypedEncoder](xs: List[X3[Int, A, B]])(
- implicit
- encEv: Encoder[(Int, A, B)],
- evCanBeDoubleA: CatalystCast[A, Double],
- evCanBeDoubleB: CatalystCast[B, Double]
- ): Prop = bivariatePropTemplate(xs)(
+ def prop[A: TypedEncoder, B: TypedEncoder](
+ xs: List[X3[Int, A, B]]
+ )(implicit
+ encEv: Encoder[(Int, A, B)],
+ evCanBeDoubleA: CatalystCast[A, Double],
+ evCanBeDoubleB: CatalystCast[B, Double]
+ ): Prop = bivariatePropTemplate(xs)(
covarSamp[A, B, X3[Int, A, B]],
org.apache.spark.sql.functions.covar_samp
)
@@ -519,11 +591,12 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: TypedEncoder](xs: List[X2[Int, A]])(
- implicit
- encEv: Encoder[(Int, A)],
- evCanBeDoubleA: CatalystCast[A, Double]
- ): Prop = univariatePropTemplate(xs)(
+ def prop[A: TypedEncoder](
+ xs: List[X2[Int, A]]
+ )(implicit
+ encEv: Encoder[(Int, A)],
+ evCanBeDoubleA: CatalystCast[A, Double]
+ ): Prop = univariatePropTemplate(xs)(
kurtosis[A, X2[Int, A]],
org.apache.spark.sql.functions.kurtosis
)
@@ -539,11 +612,12 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: TypedEncoder](xs: List[X2[Int, A]])(
- implicit
- encEv: Encoder[(Int, A)],
- evCanBeDoubleA: CatalystCast[A, Double]
- ): Prop = univariatePropTemplate(xs)(
+ def prop[A: TypedEncoder](
+ xs: List[X2[Int, A]]
+ )(implicit
+ encEv: Encoder[(Int, A)],
+ evCanBeDoubleA: CatalystCast[A, Double]
+ ): Prop = univariatePropTemplate(xs)(
skewness[A, X2[Int, A]],
org.apache.spark.sql.functions.skewness
)
@@ -559,11 +633,12 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: TypedEncoder](xs: List[X2[Int, A]])(
- implicit
- encEv: Encoder[(Int, A)],
- evCanBeDoubleA: CatalystCast[A, Double]
- ): Prop = univariatePropTemplate(xs)(
+ def prop[A: TypedEncoder](
+ xs: List[X2[Int, A]]
+ )(implicit
+ encEv: Encoder[(Int, A)],
+ evCanBeDoubleA: CatalystCast[A, Double]
+ ): Prop = univariatePropTemplate(xs)(
stddevPop[A, X2[Int, A]],
org.apache.spark.sql.functions.stddev_pop
)
@@ -579,11 +654,12 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: TypedEncoder](xs: List[X2[Int, A]])(
- implicit
- encEv: Encoder[(Int, A)],
- evCanBeDoubleA: CatalystCast[A, Double]
- ): Prop = univariatePropTemplate(xs)(
+ def prop[A: TypedEncoder](
+ xs: List[X2[Int, A]]
+ )(implicit
+ encEv: Encoder[(Int, A)],
+ evCanBeDoubleA: CatalystCast[A, Double]
+ ): Prop = univariatePropTemplate(xs)(
stddevSamp[A, X2[Int, A]],
org.apache.spark.sql.functions.stddev_samp
)
diff --git a/dataset/src/test/scala/frameless/functions/DateTimeStringBehaviourUtils.scala b/dataset/src/test/scala/frameless/functions/DateTimeStringBehaviourUtils.scala
index e22fe4337..42cccb2eb 100644
--- a/dataset/src/test/scala/frameless/functions/DateTimeStringBehaviourUtils.scala
+++ b/dataset/src/test/scala/frameless/functions/DateTimeStringBehaviourUtils.scala
@@ -3,8 +3,9 @@ package frameless.functions
import org.apache.spark.sql.Row
object DateTimeStringBehaviourUtils {
+
val nullHandler: Row => Option[Int] = _.get(0) match {
case i: Int => Some(i)
- case _ => None
+ case _ => None
}
}
diff --git a/dataset/src/test/scala/frameless/functions/DoubleBehaviourUtils.scala b/dataset/src/test/scala/frameless/functions/DoubleBehaviourUtils.scala
index f3a8be581..d19b08ea1 100644
--- a/dataset/src/test/scala/frameless/functions/DoubleBehaviourUtils.scala
+++ b/dataset/src/test/scala/frameless/functions/DoubleBehaviourUtils.scala
@@ -2,19 +2,22 @@ package frameless
package functions
/**
- * Some statistical functions in Spark can result in Double, Double.NaN or Null.
- * This tends to break ?= of the property based testing. Use the nanNullHandler function
- * here to alleviate this by mapping this NaN and Null to None. This will result in
- * functioning comparison again.
- */
+ * Some statistical functions in Spark can result in a Double, Double.NaN or null.
+ * This tends to break the ?= comparison of property-based testing. Use the nanNullHandler
+ * function here to alleviate this by mapping NaN and null to None, which makes the
+ * comparison work again.
+ */
object DoubleBehaviourUtils {
+
// Mapping with this function is needed because spark uses Double.NaN for some semantics in the
// correlation function. ?= for prop testing will use == underlying and will break because Double.NaN != Double.NaN
- private val nanHandler: Double => Option[Double] = value => if (!value.equals(Double.NaN)) Option(value) else None
+ private val nanHandler: Double => Option[Double] = value =>
+ if (!value.equals(Double.NaN)) Option(value) else None
+
// Making sure that null => None and does not result in 0.0d because of row.getAs[Double]'s use of .asInstanceOf
val nanNullHandler: Any => Option[Double] = {
- case null => None
+ case null => None
case d: Double => nanHandler(d)
- case _ => ???
+ case _ => ???
}
}
diff --git a/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala b/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala
index 470d58e5f..d6dcac14b 100644
--- a/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala
+++ b/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala
@@ -7,9 +7,14 @@ import java.nio.charset.StandardCharsets
import frameless.functions.nonAggregate._
import org.apache.commons.io.FileUtils
-import org.apache.spark.sql.{Column, Encoder, SaveMode, functions => sparkFunctions}
+import org.apache.spark.sql.{
+ Column,
+ Encoder,
+ SaveMode,
+ functions => sparkFunctions
+}
import org.scalacheck.Prop._
-import org.scalacheck.{Arbitrary, Gen, Prop}
+import org.scalacheck.{ Arbitrary, Gen, Prop }
import scala.annotation.nowarn
@@ -17,30 +22,33 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val testTempFiles = "target/testoutput"
object NonNegativeGenerators {
+
val doubleGen = for {
- s <- Gen.chooseNum(1, Int.MaxValue)
- e <- Gen.chooseNum(1, Int.MaxValue)
+ s <- Gen.chooseNum(1, Int.MaxValue)
+ e <- Gen.chooseNum(1, Int.MaxValue)
res: Double = s.toDouble / e.toDouble
} yield res
- val intGen: Gen[Int] = Gen.chooseNum(1, Int.MaxValue)
+ val intGen: Gen[Int] = Gen.chooseNum(1, Int.MaxValue)
val shortGen: Gen[Short] = Gen.chooseNum(1, Short.MaxValue)
- val longGen: Gen[Long] = Gen.chooseNum(1, Long.MaxValue)
- val byteGen: Gen[Byte] = Gen.chooseNum(1, Byte.MaxValue)
+ val longGen: Gen[Long] = Gen.chooseNum(1, Long.MaxValue)
+ val byteGen: Gen[Byte] = Gen.chooseNum(1, Byte.MaxValue)
}
object NonNegativeArbitraryNumericValues {
import NonNegativeGenerators._
- implicit val arbInt: Arbitrary[Int] = Arbitrary(intGen)
- implicit val arbDouble: Arbitrary[Double] = Arbitrary(doubleGen)
- implicit val arbLong: Arbitrary[Long] = Arbitrary(longGen)
- implicit val arbShort: Arbitrary[Short] = Arbitrary(shortGen)
- implicit val arbByte: Arbitrary[Byte] = Arbitrary(byteGen)
+ implicit val arbInt: Arbitrary[Int] = Arbitrary(intGen)
+ implicit val arbDouble: Arbitrary[Double] = Arbitrary(doubleGen)
+ implicit val arbLong: Arbitrary[Long] = Arbitrary(longGen)
+ implicit val arbShort: Arbitrary[Short] = Arbitrary(shortGen)
+ implicit val arbByte: Arbitrary[Byte] = Arbitrary(byteGen)
}
private val base64Encoder = Base64.getEncoder
+
private def base64X1String(x1: X1[String]): X1[String] = {
- def base64(str: String): String = base64Encoder.encodeToString(str.getBytes(StandardCharsets.UTF_8))
+ def base64(str: String): String =
+ base64Encoder.encodeToString(str.getBytes(StandardCharsets.UTF_8))
x1.copy(a = base64(x1.a))
}
@@ -53,9 +61,12 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder](values: List[X1[A]])(
- implicit encX1:Encoder[X1[A]],
- catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, B]) = {
+ def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ encX1: Encoder[X1[A]],
+ catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, B]
+ ) = {
val cDS = session.createDataset(values)
val resCompare = cDS
.select(sparkFunctions.negate(cDS("a")))
@@ -65,11 +76,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val typedDS = TypedDataset.create(values)
val col = typedDS('a)
- val res = typedDS
- .select(negate(col))
- .collect()
- .run()
- .toList
+ val res = typedDS.select(negate(col)).collect().run().toList
res ?= resCompare
}
@@ -77,7 +84,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(prop[Byte, Byte] _))
check(forAll(prop[Short, Short] _))
check(forAll(prop[Int, Int] _))
- check(forAll(prop[Long, Long] _))
+ check(forAll(prop[Long, Long] _))
check(forAll(prop[BigDecimal, java.math.BigDecimal] _))
}
@@ -85,7 +92,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop(values: List[X1[Boolean]], fromBase: Int, toBase: Int)(implicit encX1:Encoder[X1[Boolean]]) = {
+ def prop(
+ values: List[X1[Boolean]],
+ fromBase: Int,
+ toBase: Int
+ )(implicit
+ encX1: Encoder[X1[Boolean]]
+ ) = {
val cDS = session.createDataset(values)
val resCompare = cDS
@@ -96,11 +109,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val typedDS = TypedDataset.create(values)
val col = typedDS('a)
- val res = typedDS
- .select(not(col))
- .collect()
- .run()
- .toList
+ val res = typedDS.select(not(col)).collect().run().toList
res ?= resCompare
}
@@ -112,7 +121,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop(values: List[X1[String]], fromBase: Int, toBase: Int)(implicit encX1:Encoder[X1[String]]) = {
+ def prop(
+ values: List[X1[String]],
+ fromBase: Int,
+ toBase: Int
+ )(implicit
+ encX1: Encoder[X1[String]]
+ ) = {
val cDS = session.createDataset(values)
val resCompare = cDS
@@ -123,11 +138,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val typedDS = TypedDataset.create(values)
val col = typedDS('a)
- val res = typedDS
- .select(conv(col, fromBase, toBase))
- .collect()
- .run()
- .toList
+ val res =
+ typedDS.select(conv(col, fromBase, toBase)).collect().run().toList
res ?= resCompare
}
@@ -139,7 +151,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = {
+ def prop[A: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ encX1: Encoder[X1[A]]
+ ) = {
val cDS = session.createDataset(values)
val resCompare = cDS
.select(sparkFunctions.degrees(cDS("a")))
@@ -149,11 +165,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val typedDS = TypedDataset.create(values)
val col = typedDS('a)
- val res = typedDS
- .select(degrees(col))
- .collect()
- .run()
- .toList
+ val res = typedDS.select(degrees(col)).collect().run().toList
res ?= resCompare
}
@@ -161,12 +173,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(prop[Byte] _))
check(forAll(prop[Short] _))
check(forAll(prop[Int] _))
- check(forAll(prop[Long] _))
+ check(forAll(prop[Long] _))
check(forAll(prop[BigDecimal] _))
}
- def propBitShift[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder](typedDS: TypedDataset[X1[A]])
- (typedCol: TypedColumn[X1[A], B], sparkFunc: (Column,Int) => Column, numBits: Int): Prop = {
+ def propBitShift[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder](
+ typedDS: TypedDataset[X1[A]]
+ )(typedCol: TypedColumn[X1[A], B],
+ sparkFunc: (Column, Int) => Column,
+ numBits: Int
+ ): Prop = {
val spark = session
import spark.implicits._
@@ -176,11 +192,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.collect()
.toList
- val res = typedDS
- .select(typedCol)
- .collect()
- .run()
- .toList
+ val res = typedDS.select(typedCol).collect().run().toList
res ?= resCompare
}
@@ -190,11 +202,19 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
import spark.implicits._
@nowarn // suppress sparkFunctions.shiftRightUnsigned call which is used to maintain Spark 3.1.x backwards compat
- def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder]
- (values: List[X1[A]], numBits: Int)
- (implicit catalystBitShift: CatalystBitShift[A, B], encX1: Encoder[X1[A]]) = {
+ def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder](
+ values: List[X1[A]],
+ numBits: Int
+ )(implicit
+ catalystBitShift: CatalystBitShift[A, B],
+ encX1: Encoder[X1[A]]
+ ) = {
val typedDS = TypedDataset.create(values)
- propBitShift(typedDS)(shiftRightUnsigned(typedDS('a), numBits), sparkFunctions.shiftRightUnsigned, numBits)
+ propBitShift(typedDS)(
+ shiftRightUnsigned(typedDS('a), numBits),
+ sparkFunctions.shiftRightUnsigned,
+ numBits
+ )
}
check(forAll(prop[Byte, Int] _))
@@ -209,11 +229,19 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
import spark.implicits._
@nowarn // suppress sparkFunctions.shiftRight call which is used to maintain Spark 3.1.x backwards compat
- def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder]
- (values: List[X1[A]], numBits: Int)
- (implicit catalystBitShift: CatalystBitShift[A, B], encX1: Encoder[X1[A]]) = {
+ def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder](
+ values: List[X1[A]],
+ numBits: Int
+ )(implicit
+ catalystBitShift: CatalystBitShift[A, B],
+ encX1: Encoder[X1[A]]
+ ) = {
val typedDS = TypedDataset.create(values)
- propBitShift(typedDS)(shiftRight(typedDS('a), numBits), sparkFunctions.shiftRight, numBits)
+ propBitShift(typedDS)(
+ shiftRight(typedDS('a), numBits),
+ sparkFunctions.shiftRight,
+ numBits
+ )
}
check(forAll(prop[Byte, Int] _))
@@ -228,11 +256,19 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
import spark.implicits._
@nowarn // suppress sparkFunctions.shiftLeft call which is used to maintain Spark 3.1.x backwards compat
- def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder]
- (values: List[X1[A]], numBits: Int)
- (implicit catalystBitShift: CatalystBitShift[A, B], encX1: Encoder[X1[A]]) = {
+ def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder](
+ values: List[X1[A]],
+ numBits: Int
+ )(implicit
+ catalystBitShift: CatalystBitShift[A, B],
+ encX1: Encoder[X1[A]]
+ ) = {
val typedDS = TypedDataset.create(values)
- propBitShift(typedDS)(shiftLeft(typedDS('a), numBits), sparkFunctions.shiftLeft, numBits)
+ propBitShift(typedDS)(
+ shiftLeft(typedDS('a), numBits),
+ sparkFunctions.shiftLeft,
+ numBits
+ )
}
check(forAll(prop[Byte, Int] _))
@@ -246,27 +282,26 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder]
- (values: List[X1[A]])(
- implicit catalystAbsolute: CatalystRound[A, B], encX1: Encoder[X1[A]]
- ) = {
+ def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ catalystAbsolute: CatalystRound[A, B],
+ encX1: Encoder[X1[A]]
+ ) = {
val cDS = session.createDataset(values)
val resCompare = cDS
.select(sparkFunctions.ceil(cDS("a")))
.map(_.getAs[B](0))
.collect()
- .toList.map{
- case bigDecimal : java.math.BigDecimal => bigDecimal.setScale(0)
- case other => other
- }.asInstanceOf[List[B]]
-
+ .toList
+ .map {
+ case bigDecimal: java.math.BigDecimal => bigDecimal.setScale(0)
+ case other => other
+ }
+ .asInstanceOf[List[B]]
val typedDS = TypedDataset.create(values)
- val res = typedDS
- .select(ceil(typedDS('a)))
- .collect()
- .run()
- .toList
+ val res = typedDS.select(ceil(typedDS('a))).collect().run().toList
res ?= resCompare
}
@@ -282,20 +317,22 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop(values: List[X1[Array[Byte]]])(implicit encX1: Encoder[X1[Array[Byte]]]) = {
+ def prop(
+ values: List[X1[Array[Byte]]]
+ )(implicit
+ encX1: Encoder[X1[Array[Byte]]]
+ ) = {
Seq(224, 256, 384, 512).map { numBits =>
val cDS = session.createDataset(values)
val resCompare = cDS
.select(sparkFunctions.sha2(cDS("a"), numBits))
.map(_.getAs[String](0))
- .collect().toList
-
- val typedDS = TypedDataset.create(values)
- val res = typedDS
- .select(sha2(typedDS('a), numBits))
.collect()
- .run()
.toList
+
+ val typedDS = TypedDataset.create(values)
+ val res =
+ typedDS.select(sha2(typedDS('a), numBits)).collect().run().toList
res ?= resCompare
}.reduce(_ && _)
}
@@ -307,20 +344,21 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop(values: List[X1[Array[Byte]]])(implicit encX1: Encoder[X1[Array[Byte]]]) = {
+ def prop(
+ values: List[X1[Array[Byte]]]
+ )(implicit
+ encX1: Encoder[X1[Array[Byte]]]
+ ) = {
val cDS = session.createDataset(values)
val resCompare = cDS
.select(sparkFunctions.sha1(cDS("a")))
.map(_.getAs[String](0))
- .collect().toList
-
- val typedDS = TypedDataset.create(values)
- val res = typedDS
- .select(sha1(typedDS('a)))
.collect()
- .run()
.toList
+ val typedDS = TypedDataset.create(values)
+ val res = typedDS.select(sha1(typedDS('a))).collect().run().toList
+
res ?= resCompare
}
@@ -331,7 +369,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop(values: List[X1[Array[Byte]]])(implicit encX1: Encoder[X1[Array[Byte]]]) = {
+ def prop(
+ values: List[X1[Array[Byte]]]
+ )(implicit
+ encX1: Encoder[X1[Array[Byte]]]
+ ) = {
val cDS = session.createDataset(values)
val resCompare = cDS
.select(sparkFunctions.crc32(cDS("a")))
@@ -340,11 +382,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.toList
val typedDS = TypedDataset.create(values)
- val res = typedDS
- .select(crc32(typedDS('a)))
- .collect()
- .run()
- .toList
+ val res = typedDS.select(crc32(typedDS('a))).collect().run().toList
res ?= resCompare
}
@@ -356,27 +394,26 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder]
- (values: List[X1[A]])(
- implicit catalystAbsolute: CatalystRound[A, B], encX1: Encoder[X1[A]]
- ) = {
+ def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ catalystAbsolute: CatalystRound[A, B],
+ encX1: Encoder[X1[A]]
+ ) = {
val cDS = session.createDataset(values)
val resCompare = cDS
.select(sparkFunctions.floor(cDS("a")))
.map(_.getAs[B](0))
.collect()
- .toList.map{
- case bigDecimal : java.math.BigDecimal => bigDecimal.setScale(0)
- case other => other
- }.asInstanceOf[List[B]]
-
+ .toList
+ .map {
+ case bigDecimal: java.math.BigDecimal => bigDecimal.setScale(0)
+ case other => other
+ }
+ .asInstanceOf[List[B]]
val typedDS = TypedDataset.create(values)
- val res = typedDS
- .select(floor(typedDS('a)))
- .collect()
- .run()
- .toList
+ val res = typedDS.select(floor(typedDS('a))).collect().run().toList
res ?= resCompare
}
@@ -387,35 +424,35 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(prop[BigDecimal, java.math.BigDecimal] _))
}
-
test("abs big decimal") {
val spark = session
import spark.implicits._
- def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder]
- (values: List[X1[A]])
- (
- implicit catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, B],
- encX1:Encoder[X1[A]]
- )= {
- val cDS = session.createDataset(values)
- val resCompare = cDS
- .select(sparkFunctions.abs(cDS("a")))
- .map(_.getAs[B](0))
- .collect().toList
+ def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, B],
+ encX1: Encoder[X1[A]]
+ ) = {
+ val cDS = session.createDataset(values)
+ val resCompare = cDS
+ .select(sparkFunctions.abs(cDS("a")))
+ .map(_.getAs[B](0))
+ .collect()
+ .toList
- val typedDS = TypedDataset.create(values)
- val col = typedDS('a)
- val res = typedDS
- .select(
- abs(col)
- )
- .collect()
- .run()
- .toList
+ val typedDS = TypedDataset.create(values)
+ val col = typedDS('a)
+ val res = typedDS
+ .select(
+ abs(col)
+ )
+ .collect()
+ .run()
+ .toList
- res ?= resCompare
- }
+ res ?= resCompare
+ }
check(forAll(prop[BigDecimal, java.math.BigDecimal] _))
}
@@ -424,26 +461,22 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: TypedEncoder : Encoder]
- (values: List[X1[A]])
- (
- implicit catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, A],
- encX1: Encoder[X1[A]]
- ) = {
+ def prop[A: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, A],
+ encX1: Encoder[X1[A]]
+ ) = {
val cDS = session.createDataset(values)
val resCompare = cDS
.select(sparkFunctions.abs(cDS("a")))
.map(_.getAs[A](0))
- .collect().toList
-
-
- val typedDS = TypedDataset.create(values)
- val res = typedDS
- .select(abs(typedDS('a)))
.collect()
- .run()
.toList
+ val typedDS = TypedDataset.create(values)
+ val res = typedDS.select(abs(typedDS('a))).collect().run().toList
+
res ?= resCompare
}
@@ -453,36 +486,43 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(prop[Double] _))
}
- def propTrigonometric[A: CatalystNumeric: TypedEncoder : Encoder](typedDS: TypedDataset[X1[A]])
- (typedCol: TypedColumn[X1[A], Double], sparkFunc: Column => Column): Prop = {
- val spark = session
- import spark.implicits._
+ def propTrigonometric[A: CatalystNumeric: TypedEncoder: Encoder](
+ typedDS: TypedDataset[X1[A]]
+ )(typedCol: TypedColumn[X1[A], Double],
+ sparkFunc: Column => Column
+ ): Prop = {
+ val spark = session
+ import spark.implicits._
- val resCompare = typedDS.dataset
- .select(sparkFunc($"a"))
- .map(_.getAs[Double](0))
- .map(DoubleBehaviourUtils.nanNullHandler)
- .collect().toList
+ val resCompare = typedDS.dataset
+ .select(sparkFunc($"a"))
+ .map(_.getAs[Double](0))
+ .map(DoubleBehaviourUtils.nanNullHandler)
+ .collect()
+ .toList
- val res = typedDS
- .select(typedCol)
- .deserialized
- .map(DoubleBehaviourUtils.nanNullHandler)
- .collect()
- .run()
- .toList
+ val res = typedDS
+ .select(typedCol)
+ .deserialized
+ .map(DoubleBehaviourUtils.nanNullHandler)
+ .collect()
+ .run()
+ .toList
- res ?= resCompare
+ res ?= resCompare
}
test("cos") {
val spark = session
import spark.implicits._
- def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])
- (implicit encX1:Encoder[X1[A]]) = {
- val typedDS = TypedDataset.create(values)
- propTrigonometric(typedDS)(cos(typedDS('a)), sparkFunctions.cos)
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ encX1: Encoder[X1[A]]
+ ) = {
+ val typedDS = TypedDataset.create(values)
+ propTrigonometric(typedDS)(cos(typedDS('a)), sparkFunctions.cos)
}
check(forAll(prop[Int] _))
@@ -497,10 +537,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])
- (implicit encX1:Encoder[X1[A]]) = {
- val typedDS = TypedDataset.create(values)
- propTrigonometric(typedDS)(cosh(typedDS('a)), sparkFunctions.cosh)
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ encX1: Encoder[X1[A]]
+ ) = {
+ val typedDS = TypedDataset.create(values)
+ propTrigonometric(typedDS)(cosh(typedDS('a)), sparkFunctions.cosh)
}
check(forAll(prop[Int] _))
@@ -515,10 +558,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])
- (implicit encX1:Encoder[X1[A]]) = {
- val typedDS = TypedDataset.create(values)
- propTrigonometric(typedDS)(acos(typedDS('a)), sparkFunctions.acos)
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ encX1: Encoder[X1[A]]
+ ) = {
+ val typedDS = TypedDataset.create(values)
+ propTrigonometric(typedDS)(acos(typedDS('a)), sparkFunctions.acos)
}
check(forAll(prop[Int] _))
@@ -529,16 +575,17 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(prop[Double] _))
}
-
-
test("signum") {
val spark = session
import spark.implicits._
- def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])
- (implicit encX1:Encoder[X1[A]]) = {
- val typedDS = TypedDataset.create(values)
- propTrigonometric(typedDS)(signum(typedDS('a)), sparkFunctions.signum)
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ encX1: Encoder[X1[A]]
+ ) = {
+ val typedDS = TypedDataset.create(values)
+ propTrigonometric(typedDS)(signum(typedDS('a)), sparkFunctions.signum)
}
check(forAll(prop[Int] _))
@@ -553,10 +600,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])
- (implicit encX1:Encoder[X1[A]]) = {
- val typedDS = TypedDataset.create(values)
- propTrigonometric(typedDS)(sin(typedDS('a)), sparkFunctions.sin)
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ encX1: Encoder[X1[A]]
+ ) = {
+ val typedDS = TypedDataset.create(values)
+ propTrigonometric(typedDS)(sin(typedDS('a)), sparkFunctions.sin)
}
check(forAll(prop[Int] _))
@@ -571,10 +621,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])
- (implicit encX1:Encoder[X1[A]]) = {
- val typedDS = TypedDataset.create(values)
- propTrigonometric(typedDS)(sinh(typedDS('a)), sparkFunctions.sinh)
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ encX1: Encoder[X1[A]]
+ ) = {
+ val typedDS = TypedDataset.create(values)
+ propTrigonometric(typedDS)(sinh(typedDS('a)), sparkFunctions.sinh)
}
check(forAll(prop[Int] _))
@@ -589,10 +642,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])
- (implicit encX1:Encoder[X1[A]]) = {
- val typedDS = TypedDataset.create(values)
- propTrigonometric(typedDS)(asin(typedDS('a)), sparkFunctions.asin)
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ encX1: Encoder[X1[A]]
+ ) = {
+ val typedDS = TypedDataset.create(values)
+ propTrigonometric(typedDS)(asin(typedDS('a)), sparkFunctions.asin)
}
check(forAll(prop[Int] _))
@@ -607,10 +663,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])
- (implicit encX1:Encoder[X1[A]]) = {
- val typedDS = TypedDataset.create(values)
- propTrigonometric(typedDS)(tan(typedDS('a)), sparkFunctions.tan)
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ encX1: Encoder[X1[A]]
+ ) = {
+ val typedDS = TypedDataset.create(values)
+ propTrigonometric(typedDS)(tan(typedDS('a)), sparkFunctions.tan)
}
check(forAll(prop[Int] _))
@@ -625,10 +684,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])
- (implicit encX1:Encoder[X1[A]]) = {
- val typedDS = TypedDataset.create(values)
- propTrigonometric(typedDS)(tanh(typedDS('a)), sparkFunctions.tanh)
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ encX1: Encoder[X1[A]]
+ ) = {
+ val typedDS = TypedDataset.create(values)
+ propTrigonometric(typedDS)(tanh(typedDS('a)), sparkFunctions.tanh)
}
check(forAll(prop[Int] _))
@@ -639,48 +701,46 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(prop[Double] _))
}
- /*
- * Currently not all Collection types play nice with the Encoders.
- * This test needs to be readressed and Set readded to the Collection Typeclass once these issues are resolved.
- *
- * [[https://issues.apache.org/jira/browse/SPARK-18891]]
- * [[https://issues.apache.org/jira/browse/SPARK-21204]]
- */
- test("arrayContains"){
+ /*
+ * Currently not all Collection types play nice with the Encoders.
+ * This test needs to be readdressed and Set re-added to the Collection Typeclass once these issues are resolved.
+ *
+ * [[https://issues.apache.org/jira/browse/SPARK-18891]]
+ * [[https://issues.apache.org/jira/browse/SPARK-21204]]
+ */
+ test("arrayContains") {
val spark = session
import spark.implicits._
val listLength = 10
val idxs = Stream.continually(Range(0, listLength)).flatten.toIterator
- abstract class Nth[A, C[A]:CatalystCollection] {
+ abstract class Nth[A, C[A]: CatalystCollection] {
- def nth(c:C[A], idx:Int):A
+ def nth(c: C[A], idx: Int): A
}
- implicit def deriveListNth[A] : Nth[A, List] = new Nth[A, List] {
+ implicit def deriveListNth[A]: Nth[A, List] = new Nth[A, List] {
override def nth(c: List[A], idx: Int): A = c(idx)
}
- implicit def deriveSeqNth[A] : Nth[A, Seq] = new Nth[A, Seq] {
+ implicit def deriveSeqNth[A]: Nth[A, Seq] = new Nth[A, Seq] {
override def nth(c: Seq[A], idx: Int): A = c(idx)
}
- implicit def deriveVectorNth[A] : Nth[A, Vector] = new Nth[A, Vector] {
+ implicit def deriveVectorNth[A]: Nth[A, Vector] = new Nth[A, Vector] {
override def nth(c: Vector[A], idx: Int): A = c(idx)
}
- implicit def deriveArrayNth[A] : Nth[A, Array] = new Nth[A, Array] {
+ implicit def deriveArrayNth[A]: Nth[A, Array] = new Nth[A, Array] {
override def nth(c: Array[A], idx: Int): A = c(idx)
}
-
- def prop[C[_] : CatalystCollection]
- (
+ def prop[C[_]: CatalystCollection](
values: C[Int],
- shouldBeIn:Boolean)
- (
- implicit nth:Nth[Int, C],
+ shouldBeIn: Boolean
+ )(implicit
+ nth: Nth[Int, C],
encEv: Encoder[C[Int]],
tEncEv: TypedEncoder[C[Int]]
) = {
@@ -691,7 +751,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val resCompare = cDS
.select(sparkFunctions.array_contains(cDS("value"), contained))
.map(_.getAs[Boolean](0))
- .collect().toList
+ .collect()
+ .toList
val typedDS = TypedDataset.create(List(X1(values)))
val res = typedDS
@@ -705,10 +766,9 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(
forAll(
- Gen.listOfN(listLength, Gen.choose(0,100)),
- Gen.oneOf(true,false)
- )
- (prop[List])
+ Gen.listOfN(listLength, Gen.choose(0, 100)),
+ Gen.oneOf(true, false)
+ )(prop[List])
)
/*check( Looks like there is no Typed Encoder for Seq type yet
@@ -721,18 +781,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(
forAll(
- Gen.listOfN(listLength, Gen.choose(0,100)).map(_.toVector),
- Gen.oneOf(true,false)
- )
- (prop[Vector])
+ Gen.listOfN(listLength, Gen.choose(0, 100)).map(_.toVector),
+ Gen.oneOf(true, false)
+ )(prop[Vector])
)
check(
forAll(
- Gen.listOfN(listLength, Gen.choose(0,100)).map(_.toArray),
- Gen.oneOf(true,false)
- )
- (prop[Array])
+ Gen.listOfN(listLength, Gen.choose(0, 100)).map(_.toArray),
+ Gen.oneOf(true, false)
+ )(prop[Array])
)
}
@@ -740,14 +798,19 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: CatalystNumeric : TypedEncoder : Encoder]
- (na: A, values: List[X1[A]])(implicit encX1: Encoder[X1[A]]) = {
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ na: A,
+ values: List[X1[A]]
+ )(implicit
+ encX1: Encoder[X1[A]]
+ ) = {
val cDS = session.createDataset(X1(na) :: values)
val resCompare = cDS
.select(sparkFunctions.atan(cDS("a")))
.map(_.getAs[Double](0))
.map(DoubleBehaviourUtils.nanNullHandler)
- .collect().toList
+ .collect()
+ .toList
val typedDS = TypedDataset.create(cDS)
val res = typedDS
@@ -758,13 +821,17 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.run()
.toList
- val aggrTyped = typedDS.agg(atan(
- frameless.functions.aggregate.first(typedDS('a)))
- ).firstOption().run().get
+ val aggrTyped = typedDS
+ .agg(atan(frameless.functions.aggregate.first(typedDS('a))))
+ .firstOption()
+ .run()
+ .get
- val aggrSpark = cDS.select(
- sparkFunctions.atan(sparkFunctions.first("a")).as[Double]
- ).first()
+ val aggrSpark = cDS
+ .select(
+ sparkFunctions.atan(sparkFunctions.first("a")).as[Double]
+ )
+ .first()
(res ?= resCompare).&&(aggrTyped ?= aggrSpark)
}
@@ -781,16 +848,21 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: CatalystNumeric : TypedEncoder : Encoder,
- B: CatalystNumeric : TypedEncoder : Encoder](na: X2[A, B], values: List[X2[A, B]])
- (implicit encEv: Encoder[X2[A,B]]) = {
+ def prop[
+ A: CatalystNumeric: TypedEncoder: Encoder,
+ B: CatalystNumeric: TypedEncoder: Encoder
+ ](na: X2[A, B],
+ values: List[X2[A, B]]
+ )(implicit
+ encEv: Encoder[X2[A, B]]
+ ) = {
val cDS = session.createDataset(na +: values)
val resCompare = cDS
.select(sparkFunctions.atan2(cDS("a"), cDS("b")))
.map(_.getAs[Double](0))
.map(DoubleBehaviourUtils.nanNullHandler)
- .collect().toList
-
+ .collect()
+ .toList
val typedDS = TypedDataset.create(cDS)
val res = typedDS
@@ -801,19 +873,28 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.run()
.toList
- val aggrTyped = typedDS.agg(atan2(
- frameless.functions.aggregate.first(typedDS('a)),
- frameless.functions.aggregate.first(typedDS('b)))
- ).firstOption().run().get
+ val aggrTyped = typedDS
+ .agg(
+ atan2(
+ frameless.functions.aggregate.first(typedDS('a)),
+ frameless.functions.aggregate.first(typedDS('b))
+ )
+ )
+ .firstOption()
+ .run()
+ .get
- val aggrSpark = cDS.select(
- sparkFunctions.atan2(sparkFunctions.first("a"),sparkFunctions.first("b")).as[Double]
- ).first()
+ val aggrSpark = cDS
+ .select(
+ sparkFunctions
+ .atan2(sparkFunctions.first("a"), sparkFunctions.first("b"))
+ .as[Double]
+ )
+ .first()
(res ?= resCompare).&&(aggrTyped ?= aggrSpark)
}
-
check(forAll(prop[Int, Long] _))
check(forAll(prop[Long, Int] _))
check(forAll(prop[Short, Byte] _))
@@ -826,15 +907,20 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: CatalystNumeric : TypedEncoder : Encoder]
- (na: X1[A], value: List[X1[A]], lit:Double)(implicit encX1:Encoder[X1[A]]) = {
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ na: X1[A],
+ value: List[X1[A]],
+ lit: Double
+ )(implicit
+ encX1: Encoder[X1[A]]
+ ) = {
val cDS = session.createDataset(na +: value)
val resCompare = cDS
.select(sparkFunctions.atan2(lit, cDS("a")))
.map(_.getAs[Double](0))
.map(DoubleBehaviourUtils.nanNullHandler)
- .collect().toList
-
+ .collect()
+ .toList
val typedDS = TypedDataset.create(cDS)
val res = typedDS
@@ -845,14 +931,17 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.run()
.toList
- val aggrTyped = typedDS.agg(atan2(
- lit,
- frameless.functions.aggregate.first(typedDS('a)))
- ).firstOption().run().get
+ val aggrTyped = typedDS
+ .agg(atan2(lit, frameless.functions.aggregate.first(typedDS('a))))
+ .firstOption()
+ .run()
+ .get
- val aggrSpark = cDS.select(
- sparkFunctions.atan2(lit, sparkFunctions.first("a")).as[Double]
- ).first()
+ val aggrSpark = cDS
+ .select(
+ sparkFunctions.atan2(lit, sparkFunctions.first("a")).as[Double]
+ )
+ .first()
(res ?= resCompare).&&(aggrTyped ?= aggrSpark)
}
@@ -869,15 +958,20 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: CatalystNumeric : TypedEncoder : Encoder]
- (na: X1[A], value: List[X1[A]], lit:Double)(implicit encX1:Encoder[X1[A]]) = {
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ na: X1[A],
+ value: List[X1[A]],
+ lit: Double
+ )(implicit
+ encX1: Encoder[X1[A]]
+ ) = {
val cDS = session.createDataset(na +: value)
val resCompare = cDS
.select(sparkFunctions.atan2(cDS("a"), lit))
.map(_.getAs[Double](0))
.map(DoubleBehaviourUtils.nanNullHandler)
- .collect().toList
-
+ .collect()
+ .toList
val typedDS = TypedDataset.create(cDS)
val res = typedDS
@@ -888,19 +982,21 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.run()
.toList
- val aggrTyped = typedDS.agg(atan2(
- frameless.functions.aggregate.first(typedDS('a)),
- lit)
- ).firstOption().run().get
+ val aggrTyped = typedDS
+ .agg(atan2(frameless.functions.aggregate.first(typedDS('a)), lit))
+ .firstOption()
+ .run()
+ .get
- val aggrSpark = cDS.select(
- sparkFunctions.atan2(sparkFunctions.first("a"), lit).as[Double]
- ).first()
+ val aggrSpark = cDS
+ .select(
+ sparkFunctions.atan2(sparkFunctions.first("a"), lit).as[Double]
+ )
+ .first()
(res ?= resCompare).&&(aggrTyped ?= aggrSpark)
}
-
check(forAll(prop[Int] _))
check(forAll(prop[Long] _))
check(forAll(prop[Short] _))
@@ -909,9 +1005,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(prop[Double] _))
}
- def mathProp[A: CatalystNumeric: TypedEncoder : Encoder](typedDS: TypedDataset[X1[A]])(
- typedCol: TypedColumn[X1[A], Double], sparkFunc: Column => Column
- ): Prop = {
+ def mathProp[A: CatalystNumeric: TypedEncoder: Encoder](
+ typedDS: TypedDataset[X1[A]]
+ )(typedCol: TypedColumn[X1[A], Double],
+ sparkFunc: Column => Column
+ ): Prop = {
val spark = session
import spark.implicits._
@@ -919,7 +1017,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.select(sparkFunc($"a"))
.map(_.getAs[Double](0))
.map(DoubleBehaviourUtils.nanNullHandler)
- .collect().toList
+ .collect()
+ .toList
val res = typedDS
.select(typedCol)
@@ -936,7 +1035,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = {
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ encX1: Encoder[X1[A]]
+ ) = {
val typedDS = TypedDataset.create(values)
mathProp(typedDS)(sqrt(typedDS('a)), sparkFunctions.sqrt)
}
@@ -953,7 +1056,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = {
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ encX1: Encoder[X1[A]]
+ ) = {
val typedDS = TypedDataset.create(values)
mathProp(typedDS)(cbrt(typedDS('a)), sparkFunctions.cbrt)
}
@@ -970,7 +1077,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = {
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ encX1: Encoder[X1[A]]
+ ) = {
val typedDS = TypedDataset.create(values)
mathProp(typedDS)(exp(typedDS('a)), sparkFunctions.exp)
}
@@ -987,7 +1098,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: TypedEncoder : Encoder](values: List[X1[A]]): Prop = {
+ def prop[A: TypedEncoder: Encoder](values: List[X1[A]]): Prop = {
val spark = session
import spark.implicits._
@@ -996,14 +1107,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val resCompare = typedDS.dataset
.select(sparkFunctions.md5($"a"))
.map(_.getAs[String](0))
- .collect().toList
-
- val res = typedDS
- .select(md5(typedDS('a)))
.collect()
- .run()
.toList
+ val res = typedDS.select(md5(typedDS('a))).collect().run().toList
+
res ?= resCompare
}
@@ -1022,14 +1130,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val resCompare = typedDS.dataset
.select(sparkFunctions.factorial($"a"))
.map(_.getAs[Long](0))
- .collect().toList
-
- val res = typedDS
- .select(factorial(typedDS('a)))
.collect()
- .run()
.toList
+ val res = typedDS.select(factorial(typedDS('a))).collect().run().toList
+
res ?= resCompare
}
@@ -1040,24 +1145,25 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: TypedEncoder : Encoder](values: List[X1[A]])(
- implicit catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[A, A],
- encX1: Encoder[X1[A]]
- ) = {
+ def prop[A: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[
+ A,
+ A
+ ],
+ encX1: Encoder[X1[A]]
+ ) = {
val cDS = session.createDataset(values)
val resCompare = cDS
.select(sparkFunctions.round(cDS("a")))
.map(_.getAs[A](0))
- .collect().toList
-
-
- val typedDS = TypedDataset.create(values)
- val res = typedDS
- .select(round(typedDS('a)))
.collect()
- .run()
.toList
+ val typedDS = TypedDataset.create(values)
+ val res = typedDS.select(round(typedDS('a))).collect().run().toList
+
res ?= resCompare
}
@@ -1071,25 +1177,27 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: TypedEncoder: Encoder](values: List[X1[A]])(
- implicit catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, java.math.BigDecimal],
- encX1:Encoder[X1[A]]
- ) = {
+ def prop[A: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ catalystAbsolute: CatalystNumericWithJavaBigDecimal[
+ A,
+ java.math.BigDecimal
+ ],
+ encX1: Encoder[X1[A]]
+ ) = {
val cDS = session.createDataset(values)
val resCompare = cDS
.select(sparkFunctions.round(cDS("a")))
.map(_.getAs[java.math.BigDecimal](0))
.collect()
- .toList.map(_.setScale(0))
+ .toList
+ .map(_.setScale(0))
val typedDS = TypedDataset.create(values)
val col = typedDS('a)
- val res = typedDS
- .select(round(col))
- .collect()
- .run()
- .toList
+ val res = typedDS.select(round(col)).collect().run().toList
res ?= resCompare
}
@@ -1101,24 +1209,25 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: TypedEncoder : Encoder](values: List[X1[A]])(
- implicit catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[A, A],
- encX1: Encoder[X1[A]]
- ) = {
+ def prop[A: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[
+ A,
+ A
+ ],
+ encX1: Encoder[X1[A]]
+ ) = {
val cDS = session.createDataset(values)
val resCompare = cDS
.select(sparkFunctions.round(cDS("a"), 1))
.map(_.getAs[A](0))
- .collect().toList
-
-
- val typedDS = TypedDataset.create(values)
- val res = typedDS
- .select(round(typedDS('a), 1))
.collect()
- .run()
.toList
+ val typedDS = TypedDataset.create(values)
+ val res = typedDS.select(round(typedDS('a), 1)).collect().run().toList
+
res ?= resCompare
}
@@ -1132,25 +1241,27 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: TypedEncoder: Encoder](values: List[X1[A]])(
- implicit catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, java.math.BigDecimal],
- encX1:Encoder[X1[A]]
- ) = {
+ def prop[A: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ catalystAbsolute: CatalystNumericWithJavaBigDecimal[
+ A,
+ java.math.BigDecimal
+ ],
+ encX1: Encoder[X1[A]]
+ ) = {
val cDS = session.createDataset(values)
val resCompare = cDS
.select(sparkFunctions.round(cDS("a"), 0))
.map(_.getAs[java.math.BigDecimal](0))
.collect()
- .toList.map(_.setScale(0))
+ .toList
+ .map(_.setScale(0))
val typedDS = TypedDataset.create(values)
val col = typedDS('a)
- val res = typedDS
- .select(round(col, 0))
- .collect()
- .run()
- .toList
+ val res = typedDS.select(round(col, 0)).collect().run().toList
res ?= resCompare
}
@@ -1162,24 +1273,25 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: TypedEncoder : Encoder](values: List[X1[A]])(
- implicit catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[A, A],
- encX1: Encoder[X1[A]]
- ) = {
+ def prop[A: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[
+ A,
+ A
+ ],
+ encX1: Encoder[X1[A]]
+ ) = {
val cDS = session.createDataset(values)
val resCompare = cDS
.select(sparkFunctions.bround(cDS("a")))
.map(_.getAs[A](0))
- .collect().toList
-
-
- val typedDS = TypedDataset.create(values)
- val res = typedDS
- .select(bround(typedDS('a)))
.collect()
- .run()
.toList
+ val typedDS = TypedDataset.create(values)
+ val res = typedDS.select(bround(typedDS('a))).collect().run().toList
+
res ?= resCompare
}
@@ -1187,31 +1299,33 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(prop[Long] _))
check(forAll(prop[Short] _))
check(forAll(prop[Double] _))
- }
+ }
test("bround big decimal") {
val spark = session
import spark.implicits._
- def prop[A: TypedEncoder: Encoder](values: List[X1[A]])(
- implicit catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, java.math.BigDecimal],
- encX1:Encoder[X1[A]]
- ) = {
+ def prop[A: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ catalystAbsolute: CatalystNumericWithJavaBigDecimal[
+ A,
+ java.math.BigDecimal
+ ],
+ encX1: Encoder[X1[A]]
+ ) = {
val cDS = session.createDataset(values)
val resCompare = cDS
.select(sparkFunctions.bround(cDS("a")))
.map(_.getAs[java.math.BigDecimal](0))
.collect()
- .toList.map(_.setScale(0))
+ .toList
+ .map(_.setScale(0))
val typedDS = TypedDataset.create(values)
val col = typedDS('a)
- val res = typedDS
- .select(bround(col))
- .collect()
- .run()
- .toList
+ val res = typedDS.select(bround(col)).collect().run().toList
res ?= resCompare
}
@@ -1219,63 +1333,66 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(prop[BigDecimal] _))
}
- test("bround with scale") {
- val spark = session
- import spark.implicits._
+ test("bround with scale") {
+ val spark = session
+ import spark.implicits._
- def prop[A: TypedEncoder : Encoder](values: List[X1[A]])(
- implicit catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[A, A],
+ def prop[A: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[
+ A,
+ A
+ ],
encX1: Encoder[X1[A]]
) = {
- val cDS = session.createDataset(values)
- val resCompare = cDS
- .select(sparkFunctions.bround(cDS("a"), 1))
- .map(_.getAs[A](0))
- .collect().toList
-
-
- val typedDS = TypedDataset.create(values)
- val res = typedDS
- .select(bround(typedDS('a), 1))
- .collect()
- .run()
- .toList
+ val cDS = session.createDataset(values)
+ val resCompare = cDS
+ .select(sparkFunctions.bround(cDS("a"), 1))
+ .map(_.getAs[A](0))
+ .collect()
+ .toList
- res ?= resCompare
- }
+ val typedDS = TypedDataset.create(values)
+ val res = typedDS.select(bround(typedDS('a), 1)).collect().run().toList
- check(forAll(prop[Int] _))
- check(forAll(prop[Long] _))
- check(forAll(prop[Short] _))
- check(forAll(prop[Double] _))
+ res ?= resCompare
}
- test("bround big decimal with scale") {
- val spark = session
- import spark.implicits._
+ check(forAll(prop[Int] _))
+ check(forAll(prop[Long] _))
+ check(forAll(prop[Short] _))
+ check(forAll(prop[Double] _))
+ }
- def prop[A: TypedEncoder: Encoder](values: List[X1[A]])(
- implicit catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, java.math.BigDecimal],
- encX1:Encoder[X1[A]]
+ test("bround big decimal with scale") {
+ val spark = session
+ import spark.implicits._
+
+ def prop[A: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ catalystAbsolute: CatalystNumericWithJavaBigDecimal[
+ A,
+ java.math.BigDecimal
+ ],
+ encX1: Encoder[X1[A]]
) = {
- val cDS = session.createDataset(values)
-
- val resCompare = cDS
- .select(sparkFunctions.bround(cDS("a"), 0))
- .map(_.getAs[java.math.BigDecimal](0))
- .collect()
- .toList.map(_.setScale(0))
-
- val typedDS = TypedDataset.create(values)
- val col = typedDS('a)
- val res = typedDS
- .select(bround(col, 0))
- .collect()
- .run()
- .toList
-
- res ?= resCompare
- }
+ val cDS = session.createDataset(values)
+
+ val resCompare = cDS
+ .select(sparkFunctions.bround(cDS("a"), 0))
+ .map(_.getAs[java.math.BigDecimal](0))
+ .collect()
+ .toList
+ .map(_.setScale(0))
+
+ val typedDS = TypedDataset.create(values)
+ val col = typedDS('a)
+ val res = typedDS.select(bround(col, 0)).collect().run().toList
+
+ res ?= resCompare
+ }
check(forAll(prop[BigDecimal] _))
}
@@ -1285,10 +1402,10 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
import spark.implicits._
import NonNegativeArbitraryNumericValues._
- def prop[A: CatalystNumeric: TypedEncoder : Encoder](
- values: List[X1[A]],
- base: Double
- ): Prop = {
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ values: List[X1[A]],
+ base: Double
+ ): Prop = {
val spark = session
import spark.implicits._
val typedDS = TypedDataset.create(values)
@@ -1297,7 +1414,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.select(sparkFunctions.log(base, $"a"))
.map(_.getAs[Double](0))
.map(DoubleBehaviourUtils.nanNullHandler)
- .collect().toList
+ .collect()
+ .toList
val res = typedDS
.select(log(base, typedDS('a)))
@@ -1322,7 +1440,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
import spark.implicits._
import NonNegativeArbitraryNumericValues._
- def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = {
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ encX1: Encoder[X1[A]]
+ ) = {
val typedDS = TypedDataset.create(values)
mathProp(typedDS)(log(typedDS('a)), sparkFunctions.log)
}
@@ -1339,7 +1461,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
import spark.implicits._
import NonNegativeArbitraryNumericValues._
- def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = {
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ encX1: Encoder[X1[A]]
+ ) = {
val typedDS = TypedDataset.create(values)
mathProp(typedDS)(log2(typedDS('a)), sparkFunctions.log2)
}
@@ -1356,7 +1482,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
import spark.implicits._
import NonNegativeArbitraryNumericValues._
- def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = {
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ encX1: Encoder[X1[A]]
+ ) = {
val typedDS = TypedDataset.create(values)
mathProp(typedDS)(log1p(typedDS('a)), sparkFunctions.log1p)
}
@@ -1373,7 +1503,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
import spark.implicits._
import NonNegativeArbitraryNumericValues._
- def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = {
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ encX1: Encoder[X1[A]]
+ ) = {
val typedDS = TypedDataset.create(values)
mathProp(typedDS)(log10(typedDS('a)), sparkFunctions.log10)
}
@@ -1389,20 +1523,21 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop(values:List[X1[Array[Byte]]])(implicit encX1:Encoder[X1[Array[Byte]]]) = {
+ def prop(
+ values: List[X1[Array[Byte]]]
+ )(implicit
+ encX1: Encoder[X1[Array[Byte]]]
+ ) = {
val cDS = session.createDataset(values)
val resCompare = cDS
.select(sparkFunctions.base64(cDS("a")))
.map(_.getAs[String](0))
- .collect().toList
-
- val typedDS = TypedDataset.create(values)
- val res = typedDS
- .select(base64(typedDS('a)))
.collect()
- .run()
.toList
+ val typedDS = TypedDataset.create(values)
+ val res = typedDS.select(base64(typedDS('a))).collect().run().toList
+
val backAndForth = typedDS
.select(base64(unbase64(base64(typedDS('a)))))
.collect()
@@ -1419,10 +1554,10 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: CatalystNumeric: TypedEncoder : Encoder](
- values: List[X1[A]],
- base: Double
- ): Prop = {
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ values: List[X1[A]],
+ base: Double
+ ): Prop = {
val spark = session
import spark.implicits._
val typedDS = TypedDataset.create(values)
@@ -1431,7 +1566,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.select(sparkFunctions.hypot(base, $"a"))
.map(_.getAs[Double](0))
.map(DoubleBehaviourUtils.nanNullHandler)
- .collect().toList
+ .collect()
+ .toList
val res2 = typedDS
.select(hypot(typedDS('a), base))
@@ -1463,9 +1599,9 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: CatalystNumeric: TypedEncoder : Encoder](
- values: List[X2[A, A]]
- ): Prop = {
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ values: List[X2[A, A]]
+ ): Prop = {
val spark = session
import spark.implicits._
val typedDS = TypedDataset.create(values)
@@ -1474,7 +1610,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.select(sparkFunctions.hypot($"b", $"a"))
.map(_.getAs[Double](0))
.map(DoubleBehaviourUtils.nanNullHandler)
- .collect().toList
+ .collect()
+ .toList
val res = typedDS
.select(hypot(typedDS('b), typedDS('a)))
@@ -1498,10 +1635,10 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: CatalystNumeric: TypedEncoder : Encoder](
- values: List[X1[A]],
- base: Double
- ): Prop = {
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ values: List[X1[A]],
+ base: Double
+ ): Prop = {
val spark = session
import spark.implicits._
val typedDS = TypedDataset.create(values)
@@ -1510,7 +1647,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.select(sparkFunctions.pow(base, $"a"))
.map(_.getAs[Double](0))
.map(DoubleBehaviourUtils.nanNullHandler)
- .collect().toList
+ .collect()
+ .toList
val res = typedDS
.select(pow(base, typedDS('a)))
@@ -1524,7 +1662,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.select(sparkFunctions.pow($"a", base))
.map(_.getAs[Double](0))
.map(DoubleBehaviourUtils.nanNullHandler)
- .collect().toList
+ .collect()
+ .toList
val res2 = typedDS
.select(pow(typedDS('a), base))
@@ -1534,7 +1673,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.run()
.toList
- (res ?= resCompare) && (res2 ?= resCompare2)
+ (res ?= resCompare) && (res2 ?= resCompare2)
}
check(forAll(prop[Int] _))
@@ -1548,9 +1687,9 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A: CatalystNumeric: TypedEncoder : Encoder](
- values: List[X2[A, A]]
- ): Prop = {
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ values: List[X2[A, A]]
+ ): Prop = {
val spark = session
import spark.implicits._
val typedDS = TypedDataset.create(values)
@@ -1559,7 +1698,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.select(sparkFunctions.pow($"b", $"a"))
.map(_.getAs[Double](0))
.map(DoubleBehaviourUtils.nanNullHandler)
- .collect().toList
+ .collect()
+ .toList
val res = typedDS
.select(pow(typedDS('b), typedDS('a)))
@@ -1584,9 +1724,9 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
import spark.implicits._
import NonNegativeArbitraryNumericValues._
- def prop[A: CatalystNumeric: TypedEncoder : Encoder](
- values: List[X2[A, A]]
- ): Prop = {
+ def prop[A: CatalystNumeric: TypedEncoder: Encoder](
+ values: List[X2[A, A]]
+ ): Prop = {
val spark = session
import spark.implicits._
val typedDS = TypedDataset.create(values)
@@ -1594,14 +1734,12 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val resCompare = typedDS.dataset
.select(sparkFunctions.pmod($"b", $"a"))
.map(_.getAs[A](0))
- .collect().toList
-
- val res = typedDS
- .select(pmod(typedDS('b), typedDS('a)))
.collect()
- .run()
.toList
+ val res =
+ typedDS.select(pmod(typedDS('b), typedDS('a))).collect().run().toList
+
res ?= resCompare
}
@@ -1616,71 +1754,73 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop(values: List[X1[String]])(implicit encX1: Encoder[X1[String]]) = {
+ def prop(
+ values: List[X1[String]]
+ )(implicit
+ encX1: Encoder[X1[String]]
+ ) = {
val valuesBase64 = values.map(base64X1String)
val cDS = session.createDataset(valuesBase64)
val resCompare = cDS
.select(sparkFunctions.unbase64(cDS("a")))
.map(_.getAs[Array[Byte]](0))
- .collect().toList
-
- val typedDS = TypedDataset.create(valuesBase64)
- val res = typedDS
- .select(unbase64(typedDS('a)))
.collect()
- .run()
.toList
+ val typedDS = TypedDataset.create(valuesBase64)
+ val res = typedDS.select(unbase64(typedDS('a))).collect().run().toList
+
res.map(_.toList) ?= resCompare.map(_.toList)
}
check(forAll(prop _))
}
- test("bin"){
+ test("bin") {
val spark = session
import spark.implicits._
- def prop(values:List[X1[Long]])(implicit encX1:Encoder[X1[Long]]) = {
+ def prop(
+ values: List[X1[Long]]
+ )(implicit
+ encX1: Encoder[X1[Long]]
+ ) = {
val cDS = session.createDataset(values)
val resCompare = cDS
.select(sparkFunctions.bin(cDS("a")))
.map(_.getAs[String](0))
- .collect().toList
-
- val typedDS = TypedDataset.create(values)
- val res = typedDS
- .select(bin(typedDS('a)))
.collect()
- .run()
.toList
+ val typedDS = TypedDataset.create(values)
+ val res = typedDS.select(bin(typedDS('a))).collect().run().toList
+
res ?= resCompare
}
check(forAll(prop _))
}
- test("bitwiseNOT"){
+ test("bitwiseNOT") {
val spark = session
import spark.implicits._
    @nowarn // suppress sparkFunctions.bitwiseNOT call which is used to maintain Spark 3.1.x backwards compat
- def prop[A: CatalystBitwise : TypedEncoder : Encoder]
- (values:List[X1[A]])(implicit encX1:Encoder[X1[A]]) = {
+ def prop[A: CatalystBitwise: TypedEncoder: Encoder](
+ values: List[X1[A]]
+ )(implicit
+ encX1: Encoder[X1[A]]
+ ) = {
val cDS = session.createDataset(values)
val resCompare = cDS
.select(sparkFunctions.bitwiseNOT(cDS("a")))
.map(_.getAs[A](0))
- .collect().toList
-
- val typedDS = TypedDataset.create(values)
- val res = typedDS
- .select(bitwiseNOT(typedDS('a)))
.collect()
- .run()
.toList
+ val typedDS = TypedDataset.create(values)
+ val res = typedDS.select(bitwiseNOT(typedDS('a))).collect().run().toList
+
res ?= resCompare
}
@@ -1694,11 +1834,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A : TypedEncoder](
- toFile1: List[X1[A]],
- toFile2: List[X1[A]],
- inMem: List[X1[A]]
- )(implicit x2Gen: Encoder[X2[A, String]], x3Gen: Encoder[X3[A, String, String]]) = {
+ def prop[A: TypedEncoder](
+ toFile1: List[X1[A]],
+ toFile2: List[X1[A]],
+ inMem: List[X1[A]]
+ )(implicit
+ x2Gen: Encoder[X2[A, String]],
+ x3Gen: Encoder[X3[A, String, String]]
+ ) = {
val file1Path = testTempFiles + "/file1"
val file2Path = testTempFiles + "/file2"
@@ -1719,7 +1862,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val unioned = ds1.union(ds2).union(ds3)
- val withFileName = unioned.withColumn[X3[A, String, String]](inputFileName[X2[A, String]]())
+ val withFileName = unioned
+ .withColumn[X3[A, String, String]](inputFileName[X2[A, String]]())
.collect()
.run()
.toVector
@@ -1727,10 +1871,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val grouped = withFileName.groupBy(_.b).mapValues(_.map(_.c).toSet)
grouped.foldLeft(passed) { (p, g) =>
- p && secure { g._1 match {
- case "" => g._2.head == "" //Empty string if didn't come from file
- case f => g._2.forall(_.contains(f))
- }}}
+ p && secure {
+ g._1 match {
+ case "" => g._2.head == "" // Empty string if didn't come from file
+ case f => g._2.forall(_.contains(f))
+ }
+ }
+ }
}
check(forAll(prop[String] _))
@@ -1740,17 +1887,22 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A : TypedEncoder](xs: List[X1[A]])(implicit x2en: Encoder[X2[A, Long]]) = {
+ def prop[A: TypedEncoder](
+ xs: List[X1[A]]
+ )(implicit
+ x2en: Encoder[X2[A, Long]]
+ ) = {
val ds = TypedDataset.create(xs)
- val result = ds.withColumn[X2[A, Long]](monotonicallyIncreasingId())
+ val result = ds
+ .withColumn[X2[A, Long]](monotonicallyIncreasingId())
.collect()
.run()
.toVector
val ids = result.map(_.b)
(ids.toSet.size ?= ids.length) &&
- (ids.sorted ?= ids)
+ (ids.sorted ?= ids)
}
check(forAll(prop[String] _))
@@ -1760,13 +1912,22 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop[A : TypedEncoder : Encoder]
- (condition1: Boolean, condition2: Boolean, value1: A, value2: A, otherwise: A) = {
- val ds = TypedDataset.create(X5(condition1, condition2, value1, value2, otherwise) :: Nil)
+ def prop[A: TypedEncoder: Encoder](
+ condition1: Boolean,
+ condition2: Boolean,
+ value1: A,
+ value2: A,
+ otherwise: A
+ ) = {
+ val ds = TypedDataset.create(
+ X5(condition1, condition2, value1, value2, otherwise) :: Nil
+ )
- val untypedWhen = ds.toDF()
+ val untypedWhen = ds
+ .toDF()
.select(
- sparkFunctions.when(sparkFunctions.col("a"), sparkFunctions.col("c"))
+ sparkFunctions
+ .when(sparkFunctions.col("a"), sparkFunctions.col("c"))
.when(sparkFunctions.col("b"), sparkFunctions.col("d"))
.otherwise(sparkFunctions.col("e"))
)
@@ -1776,9 +1937,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val typedWhen = ds
.select(
- when(ds('a), ds('c))
- .when(ds('b), ds('d))
- .otherwise(ds('e))
+ when(ds('a), ds('c)).when(ds('b), ds('d)).otherwise(ds('e))
)
.collect()
.run()
@@ -1800,17 +1959,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll { values: List[X1[String]] =>
val ds = TypedDataset.create(values)
- val sparkResult = ds.toDF()
+ val sparkResult = ds
+ .toDF()
.select(sparkFunctions.ascii($"a"))
.map(_.getAs[Int](0))
.collect()
.toVector
- val typed = ds
- .select(ascii(ds('a)))
- .collect()
- .run()
- .toVector
+ val typed = ds.select(ascii(ds('a))).collect().run().toVector
typed ?= sparkResult
})
@@ -1828,19 +1984,18 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(pairs) { values: List[X2[String, String]] =>
val ds = TypedDataset.create(values)
- val sparkResult = ds.toDF()
+ val sparkResult = ds
+ .toDF()
.select(sparkFunctions.concat($"a", $"b"))
.map(_.getAs[String](0))
.collect()
.toVector
- val typed = ds
- .select(concat(ds('a), ds('b)))
- .collect()
- .run()
- .toVector
+ val typed = ds.select(concat(ds('a), ds('b))).collect().run().toVector
- (typed ?= sparkResult).&&(typed ?= values.map(x => s"${x.a}${x.b}").toVector)
+ (typed ?= sparkResult).&&(
+ typed ?= values.map(x => s"${x.a}${x.b}").toVector
+ )
})
}
@@ -1855,10 +2010,18 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(pairs) { values: List[X2[String, String]] =>
val ds = TypedDataset.create(values)
- val td = ds.agg(concat(first(ds('a)),first(ds('b)))).collect().run().toVector
- val spark = ds.dataset.select(sparkFunctions.concat(
- sparkFunctions.first($"a").as[String],
- sparkFunctions.first($"b").as[String])).as[String].collect().toVector
+ val td =
+ ds.agg(concat(first(ds('a)), first(ds('b)))).collect().run().toVector
+ val spark = ds.dataset
+ .select(
+ sparkFunctions.concat(
+ sparkFunctions.first($"a").as[String],
+ sparkFunctions.first($"b").as[String]
+ )
+ )
+ .as[String]
+ .collect()
+ .toVector
td ?= spark
})
}
@@ -1875,17 +2038,15 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(pairs) { values: List[X2[String, String]] =>
val ds = TypedDataset.create(values)
- val sparkResult = ds.toDF()
+ val sparkResult = ds
+ .toDF()
.select(sparkFunctions.concat_ws(",", $"a", $"b"))
.map(_.getAs[String](0))
.collect()
.toVector
- val typed = ds
- .select(concatWs(",", ds('a), ds('b)))
- .collect()
- .run()
- .toVector
+ val typed =
+ ds.select(concatWs(",", ds('a), ds('b))).collect().run().toVector
typed ?= sparkResult
})
@@ -1902,11 +2063,23 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(pairs) { values: List[X2[String, String]] =>
val ds = TypedDataset.create(values)
- val td = ds.agg(concatWs(",",first(ds('a)),first(ds('b)), last(ds('b)))).collect().run().toVector
- val spark = ds.dataset.select(sparkFunctions.concat_ws(",",
- sparkFunctions.first($"a").as[String],
- sparkFunctions.first($"b").as[String],
- sparkFunctions.last($"b").as[String])).as[String].collect().toVector
+ val td = ds
+ .agg(concatWs(",", first(ds('a)), first(ds('b)), last(ds('b))))
+ .collect()
+ .run()
+ .toVector
+ val spark = ds.dataset
+ .select(
+ sparkFunctions.concat_ws(
+ ",",
+ sparkFunctions.first($"a").as[String],
+ sparkFunctions.first($"b").as[String],
+ sparkFunctions.last($"b").as[String]
+ )
+ )
+ .as[String]
+ .collect()
+ .toVector
td ?= spark
})
}
@@ -1917,17 +2090,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(Gen.nonEmptyListOf(Gen.alphaStr)) { values: List[String] =>
val ds = TypedDataset.create(values.map(x => X1(x + values.head)))
- val sparkResult = ds.toDF()
+ val sparkResult = ds
+ .toDF()
.select(sparkFunctions.instr($"a", values.head))
.map(_.getAs[Int](0))
.collect()
.toVector
- val typed = ds
- .select(instr(ds('a), values.head))
- .collect()
- .run()
- .toVector
+ val typed = ds.select(instr(ds('a), values.head)).collect().run().toVector
typed ?= sparkResult
})
@@ -1939,17 +2109,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll { values: List[X1[String]] =>
val ds = TypedDataset.create(values)
- val sparkResult = ds.toDF()
+ val sparkResult = ds
+ .toDF()
.select(sparkFunctions.length($"a"))
.map(_.getAs[Int](0))
.collect()
.toVector
- val typed = ds
- .select(length(ds[String]('a)))
- .collect()
- .run()
- .toVector
+ val typed = ds.select(length(ds[String]('a))).collect().run().toVector
(typed ?= sparkResult).&&(values.map(_.a.length).toVector ?= typed)
})
@@ -1961,26 +2128,43 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll { (na: X1[String], values: List[X1[String]]) =>
val ds = TypedDataset.create(na +: values)
- val sparkResult = ds.toDF()
- .select(sparkFunctions.levenshtein($"a", sparkFunctions.concat($"a",sparkFunctions.lit("Hello"))))
+ val sparkResult = ds
+ .toDF()
+ .select(
+ sparkFunctions.levenshtein(
+ $"a",
+ sparkFunctions.concat($"a", sparkFunctions.lit("Hello"))
+ )
+ )
.map(_.getAs[Int](0))
.collect()
.toVector
val typed = ds
- .select(levenshtein(ds('a), concat(ds('a),lit("Hello"))))
+ .select(levenshtein(ds('a), concat(ds('a), lit("Hello"))))
.collect()
.run()
.toVector
val cDS = ds.dataset
- val aggrTyped = ds.agg(
- levenshtein(frameless.functions.aggregate.first(ds('a)), litAggr("Hello"))
- ).firstOption().run().get
+ val aggrTyped = ds
+ .agg(
+ levenshtein(
+ frameless.functions.aggregate.first(ds('a)),
+ litAggr("Hello")
+ )
+ )
+ .firstOption()
+ .run()
+ .get
- val aggrSpark = cDS.select(
- sparkFunctions.levenshtein(sparkFunctions.first("a"), sparkFunctions.lit("Hello")).as[Int]
- ).first()
+ val aggrSpark = cDS
+ .select(
+ sparkFunctions
+ .levenshtein(sparkFunctions.first("a"), sparkFunctions.lit("Hello"))
+ .as[Int]
+ )
+ .first()
(typed ?= sparkResult).&&(aggrTyped ?= aggrSpark)
})
@@ -1992,7 +2176,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll { (values: List[X1[String]], n: Int) =>
val ds = TypedDataset.create(values.map(x => X1(s"$n${x.a}-$n$n")))
- val sparkResult = ds.toDF()
+ val sparkResult = ds
+ .toDF()
.select(sparkFunctions.regexp_replace($"a", "\\d+", "n"))
.map(_.getAs[String](0))
.collect()
@@ -2014,17 +2199,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll { values: List[X1[String]] =>
val ds = TypedDataset.create(values)
- val sparkResult = ds.toDF()
+ val sparkResult = ds
+ .toDF()
.select(sparkFunctions.reverse($"a"))
.map(_.getAs[String](0))
.collect()
.toVector
- val typed = ds
- .select(reverse(ds[String]('a)))
- .collect()
- .run()
- .toVector
+ val typed = ds.select(reverse(ds[String]('a))).collect().run().toVector
(typed ?= sparkResult).&&(values.map(_.a.reverse).toVector ?= typed)
})
@@ -2036,17 +2218,15 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll { values: List[X1[String]] =>
val ds = TypedDataset.create(values)
- val sparkResult = ds.toDF()
+ val sparkResult = ds
+ .toDF()
.select(sparkFunctions.rpad($"a", 5, "hello"))
.map(_.getAs[String](0))
.collect()
.toVector
- val typed = ds
- .select(rpad(ds[String]('a), 5, "hello"))
- .collect()
- .run()
- .toVector
+ val typed =
+ ds.select(rpad(ds[String]('a), 5, "hello")).collect().run().toVector
typed ?= sparkResult
})
@@ -2058,17 +2238,15 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll { values: List[X1[String]] =>
val ds = TypedDataset.create(values)
- val sparkResult = ds.toDF()
+ val sparkResult = ds
+ .toDF()
.select(sparkFunctions.lpad($"a", 5, "hello"))
.map(_.getAs[String](0))
.collect()
.toVector
- val typed = ds
- .select(lpad(ds[String]('a), 5, "hello"))
- .collect()
- .run()
- .toVector
+ val typed =
+ ds.select(lpad(ds[String]('a), 5, "hello")).collect().run().toVector
typed ?= sparkResult
})
@@ -2080,17 +2258,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll { values: List[X1[String]] =>
val ds = TypedDataset.create(values.map(x => X1(s" ${x.a} ")))
- val sparkResult = ds.toDF()
+ val sparkResult = ds
+ .toDF()
.select(sparkFunctions.rtrim($"a"))
.map(_.getAs[String](0))
.collect()
.toVector
- val typed = ds
- .select(rtrim(ds[String]('a)))
- .collect()
- .run()
- .toVector
+ val typed = ds.select(rtrim(ds[String]('a))).collect().run().toVector
typed ?= sparkResult
})
@@ -2102,17 +2277,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll { values: List[X1[String]] =>
val ds = TypedDataset.create(values.map(x => X1(s" ${x.a} ")))
- val sparkResult = ds.toDF()
+ val sparkResult = ds
+ .toDF()
.select(sparkFunctions.ltrim($"a"))
.map(_.getAs[String](0))
.collect()
.toVector
- val typed = ds
- .select(ltrim(ds[String]('a)))
- .collect()
- .run()
- .toVector
+ val typed = ds.select(ltrim(ds[String]('a))).collect().run().toVector
typed ?= sparkResult
})
@@ -2124,17 +2296,15 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll { values: List[X1[String]] =>
val ds = TypedDataset.create(values)
- val sparkResult = ds.toDF()
+ val sparkResult = ds
+ .toDF()
.select(sparkFunctions.substring($"a", 5, 3))
.map(_.getAs[String](0))
.collect()
.toVector
- val typed = ds
- .select(substring(ds[String]('a), 5, 3))
- .collect()
- .run()
- .toVector
+ val typed =
+ ds.select(substring(ds[String]('a), 5, 3)).collect().run().toVector
typed ?= sparkResult
})
@@ -2146,17 +2316,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll { values: List[X1[String]] =>
val ds = TypedDataset.create(values.map(x => X1(s" ${x.a} ")))
- val sparkResult = ds.toDF()
+ val sparkResult = ds
+ .toDF()
.select(sparkFunctions.trim($"a"))
.map(_.getAs[String](0))
.collect()
.toVector
- val typed = ds
- .select(trim(ds[String]('a)))
- .collect()
- .run()
- .toVector
+ val typed = ds.select(trim(ds[String]('a))).collect().run().toVector
typed ?= sparkResult
})
@@ -2168,17 +2335,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(Gen.listOf(Gen.alphaStr)) { values: List[String] =>
val ds = TypedDataset.create(values.map(X1(_)))
- val sparkResult = ds.toDF()
+ val sparkResult = ds
+ .toDF()
.select(sparkFunctions.upper($"a"))
.map(_.getAs[String](0))
.collect()
.toVector
- val typed = ds
- .select(upper(ds[String]('a)))
- .collect()
- .run()
- .toVector
+ val typed = ds.select(upper(ds[String]('a))).collect().run().toVector
typed ?= sparkResult
})
@@ -2190,27 +2354,29 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(Gen.listOf(Gen.alphaStr)) { values: List[String] =>
val ds = TypedDataset.create(values.map(X1(_)))
- val sparkResult = ds.toDF()
+ val sparkResult = ds
+ .toDF()
.select(sparkFunctions.lower($"a"))
.map(_.getAs[String](0))
.collect()
.toVector
- val typed = ds
- .select(lower(ds[String]('a)))
- .collect()
- .run()
- .toVector
+ val typed = ds.select(lower(ds[String]('a))).collect().run().toVector
typed ?= sparkResult
})
}
test("Empty vararg tests") {
- def prop[A : TypedEncoder, B: TypedEncoder](data: Vector[X2[A, B]]) = {
+ def prop[A: TypedEncoder, B: TypedEncoder](data: Vector[X2[A, B]]) = {
val ds = TypedDataset.create(data)
- val frameless = ds.select(ds('a), concat(), ds('b), concatWs(":")).collect().run().toVector
- val framelessAggr = ds.agg(concat(), concatWs("x"), litAggr(2)).collect().run().toVector
+ val frameless = ds
+ .select(ds('a), concat(), ds('b), concatWs(":"))
+ .collect()
+ .run()
+ .toVector
+ val framelessAggr =
+ ds.agg(concat(), concatWs("x"), litAggr(2)).collect().run().toVector
val scala = data.map(x => (x.a, "", x.b, ""))
val scalaAggr = Vector(("", "", 2))
(frameless ?= scala).&&(framelessAggr ?= scalaAggr)
@@ -2220,8 +2386,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(prop[Option[Boolean], Long] _))
}
- def dateTimeStringProp(typedDS: TypedDataset[X1[String]])
- (typedCol: TypedColumn[X1[String], Option[Int]], sparkFunc: Column => Column): Prop = {
+ def dateTimeStringProp(
+ typedDS: TypedDataset[X1[String]]
+ )(typedCol: TypedColumn[X1[String], Option[Int]],
+ sparkFunc: Column => Column
+ ): Prop = {
val spark = session
import spark.implicits._
@@ -2231,11 +2400,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
.collect()
.toList
- val typed = typedDS
- .select(typedCol)
- .collect()
- .run()
- .toList
+ val typed = typedDS.select(typedCol).collect().run().toList
typed ?= sparkResult
}
@@ -2244,10 +2409,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = {
- val ds = TypedDataset.create(data)
- dateTimeStringProp(ds)(year(ds[String]('a)), sparkFunctions.year)
- }
+ def prop(
+ data: List[X1[String]]
+ )(implicit
+ E: Encoder[Option[Int]]
+ ): Prop = {
+ val ds = TypedDataset.create(data)
+ dateTimeStringProp(ds)(year(ds[String]('a)), sparkFunctions.year)
+ }
check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply))))
check(forAll(prop _))
@@ -2257,7 +2426,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = {
+ def prop(
+ data: List[X1[String]]
+ )(implicit
+ E: Encoder[Option[Int]]
+ ): Prop = {
val ds = TypedDataset.create(data)
dateTimeStringProp(ds)(quarter(ds[String]('a)), sparkFunctions.quarter)
}
@@ -2270,7 +2443,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = {
+ def prop(
+ data: List[X1[String]]
+ )(implicit
+ E: Encoder[Option[Int]]
+ ): Prop = {
val ds = TypedDataset.create(data)
dateTimeStringProp(ds)(month(ds[String]('a)), sparkFunctions.month)
}
@@ -2283,9 +2460,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = {
+ def prop(
+ data: List[X1[String]]
+ )(implicit
+ E: Encoder[Option[Int]]
+ ): Prop = {
val ds = TypedDataset.create(data)
- dateTimeStringProp(ds)(dayofweek(ds[String]('a)), sparkFunctions.dayofweek)
+ dateTimeStringProp(ds)(
+ dayofweek(ds[String]('a)),
+ sparkFunctions.dayofweek
+ )
}
check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply))))
@@ -2296,9 +2480,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = {
+ def prop(
+ data: List[X1[String]]
+ )(implicit
+ E: Encoder[Option[Int]]
+ ): Prop = {
val ds = TypedDataset.create(data)
- dateTimeStringProp(ds)(dayofmonth(ds[String]('a)), sparkFunctions.dayofmonth)
+ dateTimeStringProp(ds)(
+ dayofmonth(ds[String]('a)),
+ sparkFunctions.dayofmonth
+ )
}
check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply))))
@@ -2309,9 +2500,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = {
+ def prop(
+ data: List[X1[String]]
+ )(implicit
+ E: Encoder[Option[Int]]
+ ): Prop = {
val ds = TypedDataset.create(data)
- dateTimeStringProp(ds)(dayofyear(ds[String]('a)), sparkFunctions.dayofyear)
+ dateTimeStringProp(ds)(
+ dayofyear(ds[String]('a)),
+ sparkFunctions.dayofyear
+ )
}
check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply))))
@@ -2322,7 +2520,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = {
+ def prop(
+ data: List[X1[String]]
+ )(implicit
+ E: Encoder[Option[Int]]
+ ): Prop = {
val ds = TypedDataset.create(data)
dateTimeStringProp(ds)(hour(ds[String]('a)), sparkFunctions.hour)
}
@@ -2335,7 +2537,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = {
+ def prop(
+ data: List[X1[String]]
+ )(implicit
+ E: Encoder[Option[Int]]
+ ): Prop = {
val ds = TypedDataset.create(data)
dateTimeStringProp(ds)(minute(ds[String]('a)), sparkFunctions.minute)
}
@@ -2348,7 +2554,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = {
+ def prop(
+ data: List[X1[String]]
+ )(implicit
+ E: Encoder[Option[Int]]
+ ): Prop = {
val ds = TypedDataset.create(data)
dateTimeStringProp(ds)(second(ds[String]('a)), sparkFunctions.second)
}
@@ -2361,9 +2571,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._
- def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = {
+ def prop(
+ data: List[X1[String]]
+ )(implicit
+ E: Encoder[Option[Int]]
+ ): Prop = {
val ds = TypedDataset.create(data)
- dateTimeStringProp(ds)(weekofyear(ds[String]('a)), sparkFunctions.weekofyear)
+ dateTimeStringProp(ds)(
+ weekofyear(ds[String]('a)),
+ sparkFunctions.weekofyear
+ )
}
check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply))))
diff --git a/dataset/src/test/scala/frameless/functions/UdfTests.scala b/dataset/src/test/scala/frameless/functions/UdfTests.scala
index 10e65180f..ed1039640 100644
--- a/dataset/src/test/scala/frameless/functions/UdfTests.scala
+++ b/dataset/src/test/scala/frameless/functions/UdfTests.scala
@@ -7,14 +7,22 @@ import org.scalacheck.Prop._
class UdfTests extends TypedDatasetSuite {
test("one argument udf") {
- def prop[A: TypedEncoder, B: TypedEncoder](data: Vector[X1[A]], f1: A => B): Prop = {
+ def prop[A: TypedEncoder, B: TypedEncoder](
+ data: Vector[X1[A]],
+ f1: A => B
+ ): Prop = {
val dataset: TypedDataset[X1[A]] = TypedDataset.create(data)
val u1 = udf[X1[A], A, B](f1)
val u2 = dataset.makeUDF(f1)
val A = dataset.col[A]('a)
// filter forces whole codegen
- val codegen = dataset.deserialized.filter((_:X1[A]) => true).select(u1(A)).collect().run().toVector
+ val codegen = dataset.deserialized
+ .filter((_: X1[A]) => true)
+ .select(u1(A))
+ .collect()
+ .run()
+ .toVector
// otherwise it uses local relation
val local = dataset.select(u2(A)).collect().run().toVector
@@ -35,15 +43,22 @@ class UdfTests extends TypedDatasetSuite {
check(forAll(prop[Option[Vector[String]], Option[Vector[String]]] _))
- def prop2[A: TypedEncoder, B: TypedEncoder](f: A => B)(a: A): Prop = prop(Vector(X1(a)), f)
+ def prop2[A: TypedEncoder, B: TypedEncoder](f: A => B)(a: A): Prop =
+ prop(Vector(X1(a)), f)
- check(forAll(prop2[Int, Option[Int]](x => if (x % 2 == 0) Some(x) else None) _))
+ check(
+ forAll(prop2[Int, Option[Int]](x => if (x % 2 == 0) Some(x) else None) _)
+ )
check(forAll(prop2[Option[Int], Int](x => x getOrElse 0) _))
}
test("multiple one argument udf") {
- def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]
- (data: Vector[X3[A, B, C]], f1: A => A, f2: B => B, f3: C => C): Prop = {
+ def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder](
+ data: Vector[X3[A, B, C]],
+ f1: A => A,
+ f2: B => B,
+ f3: C => C
+ ): Prop = {
val dataset = TypedDataset.create(data)
val u11 = udf[X3[A, B, C], A, A](f1)
val u21 = udf[X3[A, B, C], B, B](f2)
@@ -55,8 +70,10 @@ class UdfTests extends TypedDatasetSuite {
val B = dataset.col[B]('b)
val C = dataset.col[C]('c)
- val dataset21 = dataset.select(u11(A), u21(B), u31(C)).collect().run().toVector
- val dataset22 = dataset.select(u12(A), u22(B), u32(C)).collect().run().toVector
+ val dataset21 =
+ dataset.select(u11(A), u21(B), u31(C)).collect().run().toVector
+ val dataset22 =
+ dataset.select(u12(A), u22(B), u32(C)).collect().run().toVector
val d = data.map(x => (f1(x.a), f2(x.b), f3(x.c)))
(dataset21 ?= d) && (dataset22 ?= d)
@@ -69,8 +86,10 @@ class UdfTests extends TypedDatasetSuite {
}
test("two argument udf") {
- def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]
- (data: Vector[X3[A, B, C]], f1: (A, B) => C): Prop = {
+ def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder](
+ data: Vector[X3[A, B, C]],
+ f1: (A, B) => C
+ ): Prop = {
val dataset = TypedDataset.create(data)
val u1 = udf[X3[A, B, C], A, B, C](f1)
val u2 = dataset.makeUDF(f1)
@@ -89,8 +108,11 @@ class UdfTests extends TypedDatasetSuite {
}
test("multiple two argument udf") {
- def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]
- (data: Vector[X3[A, B, C]], f1: (A, B) => C, f2: (B, C) => A): Prop = {
+ def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder](
+ data: Vector[X3[A, B, C]],
+ f1: (A, B) => C,
+ f2: (B, C) => A
+ ): Prop = {
val dataset = TypedDataset.create(data)
val u11 = udf[X3[A, B, C], A, B, C](f1)
val u12 = dataset.makeUDF(f1)
@@ -101,8 +123,10 @@ class UdfTests extends TypedDatasetSuite {
val B = dataset.col[B]('b)
val C = dataset.col[C]('c)
- val dataset21 = dataset.select(u11(A, B), u21(B, C)).collect().run().toVector
- val dataset22 = dataset.select(u12(A, B), u22(B, C)).collect().run().toVector
+ val dataset21 =
+ dataset.select(u11(A, B), u21(B, C)).collect().run().toVector
+ val dataset22 =
+ dataset.select(u12(A, B), u22(B, C)).collect().run().toVector
val d = data.map(x => (f1(x.a, x.b), f2(x.b, x.c)))
(dataset21 ?= d) && (dataset22 ?= d)
@@ -113,8 +137,10 @@ class UdfTests extends TypedDatasetSuite {
}
test("three argument udf") {
- def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]
- (data: Vector[X3[A, B, C]], f: (A, B, C) => C): Prop = {
+ def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder](
+ data: Vector[X3[A, B, C]],
+ f: (A, B, C) => C
+ ): Prop = {
val dataset = TypedDataset.create(data)
val u1 = udf[X3[A, B, C], A, B, C, C](f)
val u2 = dataset.makeUDF(f)
@@ -135,8 +161,14 @@ class UdfTests extends TypedDatasetSuite {
}
test("four argument udf") {
- def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder, D: TypedEncoder]
- (data: Vector[X4[A, B, C, D]], f: (A, B, C, D) => C): Prop = {
+ def prop[
+ A: TypedEncoder,
+ B: TypedEncoder,
+ C: TypedEncoder,
+ D: TypedEncoder
+ ](data: Vector[X4[A, B, C, D]],
+ f: (A, B, C, D) => C
+ ): Prop = {
val dataset = TypedDataset.create(data)
val u1 = udf[X4[A, B, C, D], A, B, C, D, C](f)
val u2 = dataset.makeUDF(f)
@@ -161,8 +193,15 @@ class UdfTests extends TypedDatasetSuite {
}
test("five argument udf") {
- def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder, D: TypedEncoder, E: TypedEncoder]
- (data: Vector[X5[A, B, C, D, E]], f: (A, B, C, D, E) => C): Prop = {
+ def prop[
+ A: TypedEncoder,
+ B: TypedEncoder,
+ C: TypedEncoder,
+ D: TypedEncoder,
+ E: TypedEncoder
+ ](data: Vector[X5[A, B, C, D, E]],
+ f: (A, B, C, D, E) => C
+ ): Prop = {
val dataset = TypedDataset.create(data)
val u1 = udf[X5[A, B, C, D, E], A, B, C, D, E, C](f)
val u2 = dataset.makeUDF(f)
diff --git a/dataset/src/test/scala/frameless/functions/UnaryFunctionsTest.scala b/dataset/src/test/scala/frameless/functions/UnaryFunctionsTest.scala
index 009179be6..23335f519 100644
--- a/dataset/src/test/scala/frameless/functions/UnaryFunctionsTest.scala
+++ b/dataset/src/test/scala/frameless/functions/UnaryFunctionsTest.scala
@@ -10,7 +10,12 @@ import scala.reflect.ClassTag
class UnaryFunctionsTest extends TypedDatasetSuite {
test("size tests") {
- def prop[F[X] <: Traversable[X] : CatalystSizableCollection, A](xs: List[X1[F[A]]])(implicit arb: Arbitrary[F[A]], enc: TypedEncoder[F[A]]): Prop = {
+ def prop[F[X] <: Traversable[X]: CatalystSizableCollection, A](
+ xs: List[X1[F[A]]]
+ )(implicit
+ arb: Arbitrary[F[A]],
+ enc: TypedEncoder[F[A]]
+ ): Prop = {
val tds = TypedDataset.create(xs)
val framelessResults = tds.select(size(tds('a))).collect().run().toVector
@@ -43,7 +48,12 @@ class UnaryFunctionsTest extends TypedDatasetSuite {
}
test("size on Map") {
- def prop[A](xs: List[X1[Map[A, A]]])(implicit arb: Arbitrary[Map[A, A]], enc: TypedEncoder[Map[A, A]]): Prop = {
+ def prop[A](
+ xs: List[X1[Map[A, A]]]
+ )(implicit
+ arb: Arbitrary[Map[A, A]],
+ enc: TypedEncoder[Map[A, A]]
+ ): Prop = {
val tds = TypedDataset.create(xs)
val framelessResults = tds.select(size(tds('a))).collect().run().toVector
@@ -58,10 +68,15 @@ class UnaryFunctionsTest extends TypedDatasetSuite {
}
test("sort in ascending order") {
- def prop[F[X] <: SeqLike[X, F[X]] : CatalystSortableCollection, A: Ordering](xs: List[X1[F[A]]])(implicit enc: TypedEncoder[F[A]]): Prop = {
+ def prop[F[X] <: SeqLike[X, F[X]]: CatalystSortableCollection, A: Ordering](
+ xs: List[X1[F[A]]]
+ )(implicit
+ enc: TypedEncoder[F[A]]
+ ): Prop = {
val tds = TypedDataset.create(xs)
- val framelessResults = tds.select(sortAscending(tds('a))).collect().run().toVector
+ val framelessResults =
+ tds.select(sortAscending(tds('a))).collect().run().toVector
val scalaResults = xs.map(x => x.a.sorted).toVector
framelessResults ?= scalaResults
@@ -78,10 +93,15 @@ class UnaryFunctionsTest extends TypedDatasetSuite {
}
test("sort in descending order") {
- def prop[F[X] <: SeqLike[X, F[X]] : CatalystSortableCollection, A: Ordering](xs: List[X1[F[A]]])(implicit enc: TypedEncoder[F[A]]): Prop = {
+ def prop[F[X] <: SeqLike[X, F[X]]: CatalystSortableCollection, A: Ordering](
+ xs: List[X1[F[A]]]
+ )(implicit
+ enc: TypedEncoder[F[A]]
+ ): Prop = {
val tds = TypedDataset.create(xs)
- val framelessResults = tds.select(sortDescending(tds('a))).collect().run().toVector
+ val framelessResults =
+ tds.select(sortDescending(tds('a))).collect().run().toVector
val scalaResults = xs.map(x => x.a.sorted.reverse).toVector
framelessResults ?= scalaResults
@@ -98,18 +118,19 @@ class UnaryFunctionsTest extends TypedDatasetSuite {
}
test("sort on array test: ascending order") {
- def prop[A: TypedEncoder : Ordering : ClassTag](xs: List[X1[Array[A]]]): Prop = {
+ def prop[A: TypedEncoder: Ordering: ClassTag](
+ xs: List[X1[Array[A]]]
+ ): Prop = {
val tds = TypedDataset.create(xs)
- val framelessResults = tds.select(sortAscending(tds('a))).collect().run().toVector
+ val framelessResults =
+ tds.select(sortAscending(tds('a))).collect().run().toVector
val scalaResults = xs.map(x => x.a.sorted).toVector
Prop {
- framelessResults
- .zip(scalaResults)
- .forall {
- case (a, b) => a sameElements b
- }
+ framelessResults.zip(scalaResults).forall {
+ case (a, b) => a sameElements b
+ }
}
}
@@ -119,18 +140,19 @@ class UnaryFunctionsTest extends TypedDatasetSuite {
}
test("sort on array test: descending order") {
- def prop[A: TypedEncoder : Ordering : ClassTag](xs: List[X1[Array[A]]]): Prop = {
+ def prop[A: TypedEncoder: Ordering: ClassTag](
+ xs: List[X1[Array[A]]]
+ ): Prop = {
val tds = TypedDataset.create(xs)
- val framelessResults = tds.select(sortDescending(tds('a))).collect().run().toVector
+ val framelessResults =
+ tds.select(sortDescending(tds('a))).collect().run().toVector
val scalaResults = xs.map(x => x.a.sorted.reverse).toVector
Prop {
- framelessResults
- .zip(scalaResults)
- .forall {
- case (a, b) => a sameElements b
- }
+ framelessResults.zip(scalaResults).forall {
+ case (a, b) => a sameElements b
+ }
}
}
diff --git a/dataset/src/test/scala/frameless/ops/ColumnTypesTest.scala b/dataset/src/test/scala/frameless/ops/ColumnTypesTest.scala
index 303eb2cbd..df094c7a2 100644
--- a/dataset/src/test/scala/frameless/ops/ColumnTypesTest.scala
+++ b/dataset/src/test/scala/frameless/ops/ColumnTypesTest.scala
@@ -8,20 +8,30 @@ import shapeless.::
class ColumnTypesTest extends TypedDatasetSuite {
test("test summoning") {
- def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder, D: TypedEncoder](data: Vector[X4[A, B, C, D]]): Prop = {
+ def prop[
+ A: TypedEncoder,
+ B: TypedEncoder,
+ C: TypedEncoder,
+ D: TypedEncoder
+ ](data: Vector[X4[A, B, C, D]]
+ ): Prop = {
val d: TypedDataset[X4[A, B, C, D]] = TypedDataset.create(data)
val hlist = d('a) :: d('b) :: d('c) :: d('d) :: HNil
- type TC[N] = TypedColumn[X4[A,B,C,D], N]
+ type TC[N] = TypedColumn[X4[A, B, C, D], N]
type IN = TC[A] :: TC[B] :: TC[C] :: TC[D] :: HNil
type OUT = A :: B :: C :: D :: HNil
- implicitly[ColumnTypes.Aux[X4[A,B,C,D], IN, OUT]]
+ implicitly[ColumnTypes.Aux[X4[A, B, C, D], IN, OUT]]
Prop.passed // successful compilation implies test correctness
}
check(forAll(prop[Int, String, X1[String], Boolean] _))
- check(forAll(prop[Vector[Int], Vector[Vector[String]], X1[String], Option[String]] _))
+ check(
+ forAll(
+ prop[Vector[Int], Vector[Vector[String]], X1[String], Option[String]] _
+ )
+ )
}
}
diff --git a/dataset/src/test/scala/frameless/ops/CubeTests.scala b/dataset/src/test/scala/frameless/ops/CubeTests.scala
index 7a06822b9..ae4c72d69 100644
--- a/dataset/src/test/scala/frameless/ops/CubeTests.scala
+++ b/dataset/src/test/scala/frameless/ops/CubeTests.scala
@@ -8,14 +8,23 @@ import org.scalacheck.Prop._
class CubeTests extends TypedDatasetSuite {
test("cube('a).agg(count())") {
- def prop[A: TypedEncoder : Ordering, Out: TypedEncoder : Numeric]
- (data: List[X1[A]])(implicit summable: CatalystSummable[A, Out]): Prop = {
+ def prop[A: TypedEncoder: Ordering, Out: TypedEncoder: Numeric](
+ data: List[X1[A]]
+ )(implicit
+ summable: CatalystSummable[A, Out]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
- val received = dataset.cube(A).agg(count()).collect().run().toVector.sortBy(_._2)
- val expected = dataset.dataset.cube("a").count().collect().toVector
- .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1))).sortBy(_._2)
+ val received =
+ dataset.cube(A).agg(count()).collect().run().toVector.sortBy(_._2)
+ val expected = dataset.dataset
+ .cube("a")
+ .count()
+ .collect()
+ .toVector
+ .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1)))
+ .sortBy(_._2)
received ?= expected
}
@@ -24,15 +33,29 @@ class CubeTests extends TypedDatasetSuite {
}
test("cube('a, 'b).agg(count())") {
- def prop[A: TypedEncoder : Ordering, B: TypedEncoder, Out: TypedEncoder : Numeric]
- (data: List[X2[A, B]])(implicit summable: CatalystSummable[B, Out]): Prop = {
+ def prop[
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder,
+ Out: TypedEncoder: Numeric
+ ](data: List[X2[A, B]]
+ )(implicit
+ summable: CatalystSummable[B, Out]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
- val received = dataset.cube(A, B).agg(count()).collect().run().toVector.sortBy(_._3)
- val expected = dataset.dataset.cube("a", "b").count().collect().toVector
- .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[Long](2))).sortBy(_._3)
+ val received =
+ dataset.cube(A, B).agg(count()).collect().run().toVector.sortBy(_._3)
+ val expected = dataset.dataset
+ .cube("a", "b")
+ .count()
+ .collect()
+ .toVector
+ .map(row =>
+ (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[Long](2))
+ )
+ .sortBy(_._3)
received ?= expected
}
@@ -41,15 +64,27 @@ class CubeTests extends TypedDatasetSuite {
}
test("cube('a).agg(sum('b)") {
- def prop[A: TypedEncoder : Ordering, B: TypedEncoder, Out: TypedEncoder : Numeric]
- (data: List[X2[A, B]])(implicit summable: CatalystSummable[B, Out]): Prop = {
+ def prop[
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder,
+ Out: TypedEncoder: Numeric
+ ](data: List[X2[A, B]]
+ )(implicit
+ summable: CatalystSummable[B, Out]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
- val received = dataset.cube(A).agg(sum(B)).collect().run().toVector.sortBy(_._2)
- val expected = dataset.dataset.cube("a").sum("b").collect().toVector
- .map(row => (Option(row.getAs[A](0)), row.getAs[Out](1))).sortBy(_._2)
+ val received =
+ dataset.cube(A).agg(sum(B)).collect().run().toVector.sortBy(_._2)
+ val expected = dataset.dataset
+ .cube("a")
+ .sum("b")
+ .collect()
+ .toVector
+ .map(row => (Option(row.getAs[A](0)), row.getAs[Out](1)))
+ .sortBy(_._2)
received ?= expected
}
@@ -58,15 +93,22 @@ class CubeTests extends TypedDatasetSuite {
}
test("cube('a).mapGroups('a, sum('b))") {
- def prop[A: TypedEncoder : Ordering, B: TypedEncoder : Numeric]
- (data: List[X2[A, B]]): Prop = {
+ def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Numeric](
+ data: List[X2[A, B]]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
- val received = dataset.cube(A)
- .deserialized.mapGroups { case (a, xs) => (a, xs.map(_.b).sum) }
- .collect().run().toVector.sortBy(_._1)
- val expected = data.groupBy(_.a).mapValues(_.map(_.b).sum).toVector.sortBy(_._1)
+ val received = dataset
+ .cube(A)
+ .deserialized
+ .mapGroups { case (a, xs) => (a, xs.map(_.b).sum) }
+ .collect()
+ .run()
+ .toVector
+ .sortBy(_._1)
+ val expected =
+ data.groupBy(_.a).mapValues(_.map(_.b).sum).toVector.sortBy(_._1)
received ?= expected
}
@@ -76,16 +118,16 @@ class CubeTests extends TypedDatasetSuite {
test("cube('a).agg(sum('b), sum('c)) to cube('a).agg(sum('a), sum('b), sum('a), sum('b), sum('a))") {
def prop[
- A: TypedEncoder : Ordering,
- B: TypedEncoder,
- C: TypedEncoder,
- OutB: TypedEncoder : Numeric,
- OutC: TypedEncoder : Numeric
- ](data: List[X3[A, B, C]])(
- implicit
- summableB: CatalystSummable[B, OutB],
- summableC: CatalystSummable[C, OutC]
- ): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder,
+ C: TypedEncoder,
+ OutB: TypedEncoder: Numeric,
+ OutC: TypedEncoder: Numeric
+ ](data: List[X3[A, B, C]]
+ )(implicit
+ summableB: CatalystSummable[B, OutB],
+ summableC: CatalystSummable[C, OutC]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
@@ -94,37 +136,91 @@ class CubeTests extends TypedDatasetSuite {
val framelessSumBC = dataset
.cube(A)
.agg(sum(B), sum(C))
- .collect().run().toVector.sortBy(_._1)
+ .collect()
+ .run()
+ .toVector
+ .sortBy(_._1)
- val sparkSumBC = dataset.dataset.cube("a").sum("b", "c").collect().toVector
- .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2)))
+ val sparkSumBC = dataset.dataset
+ .cube("a")
+ .sum("b", "c")
+ .collect()
+ .toVector
+ .map(row =>
+ (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2))
+ )
.sortBy(_._1)
val framelessSumBCB = dataset
.cube(A)
.agg(sum(B), sum(C), sum(B))
- .collect().run().toVector.sortBy(_._1)
+ .collect()
+ .run()
+ .toVector
+ .sortBy(_._1)
- val sparkSumBCB = dataset.dataset.cube("a").sum("b", "c", "b").collect().toVector
- .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2), row.getAs[OutB](3)))
+ val sparkSumBCB = dataset.dataset
+ .cube("a")
+ .sum("b", "c", "b")
+ .collect()
+ .toVector
+ .map(row =>
+ (
+ Option(row.getAs[A](0)),
+ row.getAs[OutB](1),
+ row.getAs[OutC](2),
+ row.getAs[OutB](3)
+ )
+ )
.sortBy(_._1)
val framelessSumBCBC = dataset
.cube(A)
.agg(sum(B), sum(C), sum(B), sum(C))
- .collect().run().toVector.sortBy(_._1)
+ .collect()
+ .run()
+ .toVector
+ .sortBy(_._1)
- val sparkSumBCBC = dataset.dataset.cube("a").sum("b", "c", "b", "c").collect().toVector
- .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2), row.getAs[OutB](3), row.getAs[OutC](4)))
+ val sparkSumBCBC = dataset.dataset
+ .cube("a")
+ .sum("b", "c", "b", "c")
+ .collect()
+ .toVector
+ .map(row =>
+ (
+ Option(row.getAs[A](0)),
+ row.getAs[OutB](1),
+ row.getAs[OutC](2),
+ row.getAs[OutB](3),
+ row.getAs[OutC](4)
+ )
+ )
.sortBy(_._1)
val framelessSumBCBCB = dataset
.cube(A)
.agg(sum(B), sum(C), sum(B), sum(C), sum(B))
- .collect().run().toVector.sortBy(_._1)
+ .collect()
+ .run()
+ .toVector
+ .sortBy(_._1)
- val sparkSumBCBCB = dataset.dataset.cube("a").sum("b", "c", "b", "c", "b").collect().toVector
- .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2), row.getAs[OutB](3), row.getAs[OutC](4), row.getAs[OutB](5)))
+ val sparkSumBCBCB = dataset.dataset
+ .cube("a")
+ .sum("b", "c", "b", "c", "b")
+ .collect()
+ .toVector
+ .map(row =>
+ (
+ Option(row.getAs[A](0)),
+ row.getAs[OutB](1),
+ row.getAs[OutC](2),
+ row.getAs[OutB](3),
+ row.getAs[OutC](4),
+ row.getAs[OutB](5)
+ )
+ )
.sortBy(_._1)
(framelessSumBC ?= sparkSumBC)
@@ -138,17 +234,17 @@ class CubeTests extends TypedDatasetSuite {
test("cube('a, 'b).agg(sum('c), sum('d))") {
def prop[
- A: TypedEncoder : Ordering,
- B: TypedEncoder : Ordering,
- C: TypedEncoder,
- D: TypedEncoder,
- OutC: TypedEncoder : Numeric,
- OutD: TypedEncoder : Numeric
- ](data: List[X4[A, B, C, D]])(
- implicit
- summableC: CatalystSummable[C, OutC],
- summableD: CatalystSummable[D, OutD]
- ): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder: Ordering,
+ C: TypedEncoder,
+ D: TypedEncoder,
+ OutC: TypedEncoder: Numeric,
+ OutD: TypedEncoder: Numeric
+ ](data: List[X4[A, B, C, D]]
+ )(implicit
+ summableC: CatalystSummable[C, OutC],
+ summableD: CatalystSummable[D, OutD]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
@@ -158,11 +254,24 @@ class CubeTests extends TypedDatasetSuite {
val framelessSumByAB = dataset
.cube(A, B)
.agg(sum(C), sum(D))
- .collect().run().toVector.sortBy(x => (x._1, x._2))
+ .collect()
+ .run()
+ .toVector
+ .sortBy(x => (x._1, x._2))
val sparkSumByAB = dataset.dataset
- .cube("a", "b").sum("c", "d").collect().toVector
- .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutD](3)))
+ .cube("a", "b")
+ .sum("c", "d")
+ .collect()
+ .toVector
+ .map(row =>
+ (
+ Option(row.getAs[A](0)),
+ Option(row.getAs[B](1)),
+ row.getAs[OutC](2),
+ row.getAs[OutD](3)
+ )
+ )
.sortBy(x => (x._1, x._2))
framelessSumByAB ?= sparkSumByAB
@@ -173,76 +282,135 @@ class CubeTests extends TypedDatasetSuite {
test("cube('a, 'b).agg(sum('c)) to cube('a, 'b).agg(sum('c),sum('c),sum('c),sum('c),sum('c))") {
def prop[
- A: TypedEncoder : Ordering,
- B: TypedEncoder : Ordering,
- C: TypedEncoder,
- OutC: TypedEncoder: Numeric
- ](data: List[X3[A, B, C]])(implicit summableC: CatalystSummable[C, OutC]): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder: Ordering,
+ C: TypedEncoder,
+ OutC: TypedEncoder: Numeric
+ ](data: List[X3[A, B, C]]
+ )(implicit
+ summableC: CatalystSummable[C, OutC]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
val C = dataset.col[C]('c)
- val framelessSumC = dataset
- .cube(A, B)
- .agg(sum(C))
- .collect().run().toVector
- .sortBy(_._2)
+ val framelessSumC =
+ dataset.cube(A, B).agg(sum(C)).collect().run().toVector.sortBy(_._2)
val sparkSumC = dataset.dataset
- .cube("a", "b").sum("c").collect().toVector
- .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2)))
+ .cube("a", "b")
+ .sum("c")
+ .collect()
+ .toVector
+ .map(row =>
+ (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2))
+ )
.sortBy(_._2)
val framelessSumCC = dataset
.cube(A, B)
.agg(sum(C), sum(C))
- .collect().run().toVector
+ .collect()
+ .run()
+ .toVector
.sortBy(_._2)
val sparkSumCC = dataset.dataset
- .cube("a", "b").sum("c", "c").collect().toVector
- .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3)))
+ .cube("a", "b")
+ .sum("c", "c")
+ .collect()
+ .toVector
+ .map(row =>
+ (
+ Option(row.getAs[A](0)),
+ Option(row.getAs[B](1)),
+ row.getAs[OutC](2),
+ row.getAs[OutC](3)
+ )
+ )
.sortBy(_._2)
val framelessSumCCC = dataset
.cube(A, B)
.agg(sum(C), sum(C), sum(C))
- .collect().run().toVector
+ .collect()
+ .run()
+ .toVector
.sortBy(_._2)
val sparkSumCCC = dataset.dataset
- .cube("a", "b").sum("c", "c", "c").collect().toVector
- .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3), row.getAs[OutC](4)))
+ .cube("a", "b")
+ .sum("c", "c", "c")
+ .collect()
+ .toVector
+ .map(row =>
+ (
+ Option(row.getAs[A](0)),
+ Option(row.getAs[B](1)),
+ row.getAs[OutC](2),
+ row.getAs[OutC](3),
+ row.getAs[OutC](4)
+ )
+ )
.sortBy(_._2)
val framelessSumCCCC = dataset
.cube(A, B)
.agg(sum(C), sum(C), sum(C), sum(C))
- .collect().run().toVector
+ .collect()
+ .run()
+ .toVector
.sortBy(_._2)
val sparkSumCCCC = dataset.dataset
- .cube("a", "b").sum("c", "c", "c", "c").collect().toVector
- .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3), row.getAs[OutC](4), row.getAs[OutC](5)))
+ .cube("a", "b")
+ .sum("c", "c", "c", "c")
+ .collect()
+ .toVector
+ .map(row =>
+ (
+ Option(row.getAs[A](0)),
+ Option(row.getAs[B](1)),
+ row.getAs[OutC](2),
+ row.getAs[OutC](3),
+ row.getAs[OutC](4),
+ row.getAs[OutC](5)
+ )
+ )
.sortBy(_._2)
val framelessSumCCCCC = dataset
.cube(A, B)
.agg(sum(C), sum(C), sum(C), sum(C), sum(C))
- .collect().run().toVector
+ .collect()
+ .run()
+ .toVector
.sortBy(_._2)
val sparkSumCCCCC = dataset.dataset
- .cube("a", "b").sum("c", "c", "c", "c", "c").collect().toVector
- .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3), row.getAs[OutC](4), row.getAs[OutC](5), row.getAs[OutC](6)))
+ .cube("a", "b")
+ .sum("c", "c", "c", "c", "c")
+ .collect()
+ .toVector
+ .map(row =>
+ (
+ Option(row.getAs[A](0)),
+ Option(row.getAs[B](1)),
+ row.getAs[OutC](2),
+ row.getAs[OutC](3),
+ row.getAs[OutC](4),
+ row.getAs[OutC](5),
+ row.getAs[OutC](6)
+ )
+ )
.sortBy(_._2)
(framelessSumC ?= sparkSumC) &&
- (framelessSumCC ?= sparkSumCC) &&
- (framelessSumCCC ?= sparkSumCCC) &&
- (framelessSumCCCC ?= sparkSumCCCC) &&
- (framelessSumCCCCC ?= sparkSumCCCCC)
+ (framelessSumCC ?= sparkSumCC) &&
+ (framelessSumCCC ?= sparkSumCCC) &&
+ (framelessSumCCCC ?= sparkSumCCCC) &&
+ (framelessSumCCCCC ?= sparkSumCCCCC)
}
check(forAll(prop[String, Long, Double, Double] _))
@@ -250,22 +418,30 @@ class CubeTests extends TypedDatasetSuite {
test("cube('a, 'b).mapGroups('a, 'b, sum('c))") {
def prop[
- A: TypedEncoder : Ordering,
- B: TypedEncoder : Ordering,
- C: TypedEncoder : Numeric
- ](data: List[X3[A, B, C]]): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder: Ordering,
+ C: TypedEncoder: Numeric
+ ](data: List[X3[A, B, C]]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
val framelessSumByAB = dataset
.cube(A, B)
- .deserialized.mapGroups { case ((a, b), xs) => (a, b, xs.map(_.c).sum) }
- .collect().run().toVector.sortBy(x => (x._1, x._2))
+ .deserialized
+ .mapGroups { case ((a, b), xs) => (a, b, xs.map(_.c).sum) }
+ .collect()
+ .run()
+ .toVector
+ .sortBy(x => (x._1, x._2))
- val sumByAB = data.groupBy(x => (x.a, x.b))
+ val sumByAB = data
+ .groupBy(x => (x.a, x.b))
.mapValues { xs => xs.map(_.c).sum }
- .toVector.map { case ((a, b), c) => (a, b, c) }.sortBy(x => (x._1, x._2))
+ .toVector
+ .map { case ((a, b), c) => (a, b, c) }
+ .sortBy(x => (x._1, x._2))
framelessSumByAB ?= sumByAB
}
@@ -274,17 +450,19 @@ class CubeTests extends TypedDatasetSuite {
}
test("cube('a).mapGroups(('a, toVector(('a, 'b))") {
- def prop[
- A: TypedEncoder: Ordering,
- B: TypedEncoder: Ordering,
- ](data: Vector[X2[A, B]]): Prop = {
+ def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Ordering](
+ data: Vector[X2[A, B]]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val datasetGrouped = dataset
.cube(A)
- .deserialized.mapGroups((a, xs) => (a, xs.toVector.sorted))
- .collect().run().toMap
+ .deserialized
+ .mapGroups((a, xs) => (a, xs.toVector.sorted))
+ .collect()
+ .run()
+ .toMap
val dataGrouped = data.groupBy(_.a).map { case (k, v) => k -> v.sorted }
@@ -297,21 +475,23 @@ class CubeTests extends TypedDatasetSuite {
}
test("cube('a).flatMapGroups(('a, toVector(('a, 'b))") {
- def prop[
- A: TypedEncoder : Ordering,
- B: TypedEncoder : Ordering
- ](data: Vector[X2[A, B]]): Prop = {
+ def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Ordering](
+ data: Vector[X2[A, B]]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val datasetGrouped = dataset
.cube(A)
- .deserialized.flatMapGroups((a, xs) => xs.map(x => (a, x)))
- .collect().run()
+ .deserialized
+ .flatMapGroups((a, xs) => xs.map(x => (a, x)))
+ .collect()
+ .run()
.sorted
val dataGrouped = data
- .groupBy(_.a).toSeq
+ .groupBy(_.a)
+ .toSeq
.flatMap { case (a, xs) => xs.map(x => (a, x)) }
.sorted
@@ -325,22 +505,26 @@ class CubeTests extends TypedDatasetSuite {
test("cube('a, 'b).flatMapGroups((('a,'b) toVector((('a,'b), 'c))") {
def prop[
- A: TypedEncoder : Ordering,
- B: TypedEncoder : Ordering,
- C: TypedEncoder : Ordering
- ](data: Vector[X3[A, B, C]]): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder: Ordering,
+ C: TypedEncoder: Ordering
+ ](data: Vector[X3[A, B, C]]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val cA = dataset.col[A]('a)
val cB = dataset.col[B]('b)
val datasetGrouped = dataset
.cube(cA, cB)
- .deserialized.flatMapGroups((a, xs) => xs.map(x => (a, x)))
- .collect().run()
+ .deserialized
+ .flatMapGroups((a, xs) => xs.map(x => (a, x)))
+ .collect()
+ .run()
.sorted
val dataGrouped = data
- .groupBy(t => (t.a, t.b)).toSeq
+ .groupBy(t => (t.a, t.b))
+ .toSeq
.flatMap { case (a, xs) => xs.map(x => (a, x)) }
.sorted
@@ -353,18 +537,32 @@ class CubeTests extends TypedDatasetSuite {
}
test("cubeMany('a).agg(sum('b))") {
- def prop[A: TypedEncoder : Ordering, Out: TypedEncoder : Numeric]
- (data: List[X1[A]])(implicit summable: CatalystSummable[A, Out]): Prop = {
+ def prop[A: TypedEncoder: Ordering, Out: TypedEncoder: Numeric](
+ data: List[X1[A]]
+ )(implicit
+ summable: CatalystSummable[A, Out]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
- val received = dataset.cubeMany(A).agg(count[X1[A]]()).collect().run().toVector.sortBy(_._2)
- val expected = dataset.dataset.cube("a").count().collect().toVector
- .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1))).sortBy(_._2)
+ val received = dataset
+ .cubeMany(A)
+ .agg(count[X1[A]]())
+ .collect()
+ .run()
+ .toVector
+ .sortBy(_._2)
+ val expected = dataset.dataset
+ .cube("a")
+ .count()
+ .collect()
+ .toVector
+ .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1)))
+ .sortBy(_._2)
received ?= expected
}
check(forAll(prop[Int, Long] _))
}
-}
\ No newline at end of file
+}
diff --git a/dataset/src/test/scala/frameless/ops/PivotTest.scala b/dataset/src/test/scala/frameless/ops/PivotTest.scala
index dd9bf5e61..f3f2bbcf9 100644
--- a/dataset/src/test/scala/frameless/ops/PivotTest.scala
+++ b/dataset/src/test/scala/frameless/ops/PivotTest.scala
@@ -2,12 +2,13 @@ package frameless
package ops
import frameless.functions.aggregate._
-import org.apache.spark.sql.{functions => sparkFunctions}
+import org.apache.spark.sql.{ functions => sparkFunctions }
import org.scalacheck.Arbitrary.arbitrary
import org.scalacheck.Prop._
-import org.scalacheck.{Gen, Prop}
+import org.scalacheck.{ Gen, Prop }
class PivotTest extends TypedDatasetSuite {
+
def withCustomGenX4: Gen[Vector[X4[String, String, Int, Boolean]]] = {
val kvPairGen: Gen[X4[String, String, Int, Boolean]] = for {
a <- Gen.oneOf(Seq("1", "2", "3", "4"))
@@ -22,77 +23,109 @@ class PivotTest extends TypedDatasetSuite {
test("X4[Boolean, String, Int, Boolean] pivot on String") {
def prop(data: Vector[X4[String, String, Int, Boolean]]): Prop = {
val d = TypedDataset.create(data)
- val frameless = d.groupBy(d('a)).
- pivot(d('b)).on("a", "b", "c").
- agg(sum(d('c)), first(d('d))).collect().run().toVector
+ val frameless = d
+ .groupBy(d('a))
+ .pivot(d('b))
+ .on("a", "b", "c")
+ .agg(sum(d('c)), first(d('d)))
+ .collect()
+ .run()
+ .toVector
- val spark = d.dataset.groupBy("a")
+ val spark = d.dataset
+ .groupBy("a")
.pivot("b", Seq("a", "b", "c"))
- .agg(sparkFunctions.sum("c"), sparkFunctions.first("d")).collect().toVector
+ .agg(sparkFunctions.sum("c"), sparkFunctions.first("d"))
+ .collect()
+ .toVector
- (frameless.map(_._1) ?= spark.map(x => x.getAs[String](0))).&&(
- frameless.map(_._2) ?= spark.map(x => Option(x.getAs[Long](1)))).&&(
- frameless.map(_._3) ?= spark.map(x => Option(x.getAs[Boolean](2)))).&&(
- frameless.map(_._4) ?= spark.map(x => Option(x.getAs[Long](3)))).&&(
- frameless.map(_._5) ?= spark.map(x => Option(x.getAs[Boolean](4)))).&&(
- frameless.map(_._6) ?= spark.map(x => Option(x.getAs[Long](5)))).&&(
- frameless.map(_._7) ?= spark.map(x => Option(x.getAs[Boolean](6))))
+ (frameless.map(_._1) ?= spark.map(x => x.getAs[String](0)))
+ .&&(frameless.map(_._2) ?= spark.map(x => Option(x.getAs[Long](1))))
+ .&&(frameless.map(_._3) ?= spark.map(x => Option(x.getAs[Boolean](2))))
+ .&&(frameless.map(_._4) ?= spark.map(x => Option(x.getAs[Long](3))))
+ .&&(frameless.map(_._5) ?= spark.map(x => Option(x.getAs[Boolean](4))))
+ .&&(frameless.map(_._6) ?= spark.map(x => Option(x.getAs[Long](5))))
+ .&&(frameless.map(_._7) ?= spark.map(x => Option(x.getAs[Boolean](6))))
}
check(forAll(withCustomGenX4)(prop))
}
test("Pivot on Boolean") {
- val x: Seq[X3[String, Boolean, Boolean]] = Seq(X3("a", true, true), X3("a", true, true), X3("a", true, false))
+ val x: Seq[X3[String, Boolean, Boolean]] =
+ Seq(X3("a", true, true), X3("a", true, true), X3("a", true, false))
val d = TypedDataset.create(x)
- d.groupByMany(d('a)).
- pivot(d('c)).on(true, false).
- agg(count[X3[String, Boolean, Boolean]]()).
- collect().run().toVector ?= Vector(("a", Some(2L), Some(1L))) // two true one false
+ d.groupByMany(d('a))
+ .pivot(d('c))
+ .on(true, false)
+ .agg(count[X3[String, Boolean, Boolean]]())
+ .collect()
+ .run()
+ .toVector ?= Vector(("a", Some(2L), Some(1L))) // two true one false
}
test("Pivot with groupBy on two columns, pivot on Long") {
- val x: Seq[X3[String, String, Long]] = Seq(X3("a", "x", 1), X3("a", "x", 1), X3("a", "c", 20))
+ val x: Seq[X3[String, String, Long]] =
+ Seq(X3("a", "x", 1), X3("a", "x", 1), X3("a", "c", 20))
val d = TypedDataset.create(x)
- d.groupBy(d('a), d('b)).
- pivot(d('c)).on(1L, 20L).
- agg(count[X3[String, String, Long]]()).
- collect().run().toSet ?= Set(("a", "x", Some(2L), None), ("a", "c", None, Some(1L)))
+ d.groupBy(d('a), d('b))
+ .pivot(d('c))
+ .on(1L, 20L)
+ .agg(count[X3[String, String, Long]]())
+ .collect()
+ .run()
+ .toSet ?= Set(("a", "x", Some(2L), None), ("a", "c", None, Some(1L)))
}
test("Pivot with cube on two columns, pivot on Long") {
- val x: Seq[X3[String, String, Long]] = Seq(X3("a", "x", 1), X3("a", "x", 1), X3("a", "c", 20))
+ val x: Seq[X3[String, String, Long]] =
+ Seq(X3("a", "x", 1), X3("a", "x", 1), X3("a", "c", 20))
val d = TypedDataset.create(x)
d.cube(d('a), d('b))
- .pivot(d('c)).on(1L, 20L)
+ .pivot(d('c))
+ .on(1L, 20L)
.agg(count[X3[String, String, Long]]())
- .collect().run().toSet ?= Set(("a", "x", Some(2L), None), ("a", "c", None, Some(1L)))
+ .collect()
+ .run()
+ .toSet ?= Set(("a", "x", Some(2L), None), ("a", "c", None, Some(1L)))
}
test("Pivot with cube on Boolean") {
- val x: Seq[X3[String, Boolean, Boolean]] = Seq(X3("a", true, true), X3("a", true, true), X3("a", true, false))
+ val x: Seq[X3[String, Boolean, Boolean]] =
+ Seq(X3("a", true, true), X3("a", true, true), X3("a", true, false))
val d = TypedDataset.create(x)
- d.cube(d('a)).
- pivot(d('c)).on(true, false).
- agg(count[X3[String, Boolean, Boolean]]()).
- collect().run().toVector ?= Vector(("a", Some(2L), Some(1L)))
+ d.cube(d('a))
+ .pivot(d('c))
+ .on(true, false)
+ .agg(count[X3[String, Boolean, Boolean]]())
+ .collect()
+ .run()
+ .toVector ?= Vector(("a", Some(2L), Some(1L)))
}
test("Pivot with rollup on two columns, pivot on Long") {
- val x: Seq[X3[String, String, Long]] = Seq(X3("a", "x", 1), X3("a", "x", 1), X3("a", "c", 20))
+ val x: Seq[X3[String, String, Long]] =
+ Seq(X3("a", "x", 1), X3("a", "x", 1), X3("a", "c", 20))
val d = TypedDataset.create(x)
d.rollup(d('a), d('b))
- .pivot(d('c)).on(1L, 20L)
+ .pivot(d('c))
+ .on(1L, 20L)
.agg(count[X3[String, String, Long]]())
- .collect().run().toSet ?= Set(("a", "x", Some(2L), None), ("a", "c", None, Some(1L)))
+ .collect()
+ .run()
+ .toSet ?= Set(("a", "x", Some(2L), None), ("a", "c", None, Some(1L)))
}
test("Pivot with rollup on Boolean") {
- val x: Seq[X3[String, Boolean, Boolean]] = Seq(X3("a", true, true), X3("a", true, true), X3("a", true, false))
+ val x: Seq[X3[String, Boolean, Boolean]] =
+ Seq(X3("a", true, true), X3("a", true, true), X3("a", true, false))
val d = TypedDataset.create(x)
- d.rollupMany(d('a)).
- pivot(d('c)).on(true, false).
- agg(count[X3[String, Boolean, Boolean]]()).
- collect().run().toVector ?= Vector(("a", Some(2L), Some(1L)))
+ d.rollupMany(d('a))
+ .pivot(d('c))
+ .on(true, false)
+ .agg(count[X3[String, Boolean, Boolean]]())
+ .collect()
+ .run()
+ .toVector ?= Vector(("a", Some(2L), Some(1L)))
}
-}
\ No newline at end of file
+}
diff --git a/dataset/src/test/scala/frameless/ops/RepeatTest.scala b/dataset/src/test/scala/frameless/ops/RepeatTest.scala
index 78dfc6410..0dd6905fb 100644
--- a/dataset/src/test/scala/frameless/ops/RepeatTest.scala
+++ b/dataset/src/test/scala/frameless/ops/RepeatTest.scala
@@ -2,17 +2,31 @@ package frameless
package ops
import shapeless.test.illTyped
-import shapeless.{::, HNil, Nat}
+import shapeless.{ ::, HNil, Nat }
class RepeatTest extends TypedDatasetSuite {
test("summoning with implicitly") {
- implicitly[Repeat.Aux[Int::Boolean::HNil, Nat._1, Int::Boolean::HNil]]
- implicitly[Repeat.Aux[Int::Boolean::HNil, Nat._2, Int::Boolean::Int::Boolean::HNil]]
- implicitly[Repeat.Aux[Int::Boolean::HNil, Nat._3, Int::Boolean::Int::Boolean::Int::Boolean::HNil]]
- implicitly[Repeat.Aux[String::HNil, Nat._5, String::String::String::String::String::HNil]]
+ implicitly[
+ Repeat.Aux[Int :: Boolean :: HNil, Nat._1, Int :: Boolean :: HNil]
+ ]
+ implicitly[Repeat.Aux[
+ Int :: Boolean :: HNil,
+ Nat._2,
+ Int :: Boolean :: Int :: Boolean :: HNil
+ ]]
+ implicitly[Repeat.Aux[
+ Int :: Boolean :: HNil,
+ Nat._3,
+ Int :: Boolean :: Int :: Boolean :: Int :: Boolean :: HNil
+ ]]
+ implicitly[Repeat.Aux[
+ String :: HNil,
+ Nat._5,
+ String :: String :: String :: String :: String :: HNil
+ ]]
}
test("ill typed") {
illTyped("""implicitly[Repeat.Aux[String::HNil, Nat._5, String::String::String::String::HNil]]""")
}
-}
\ No newline at end of file
+}
diff --git a/dataset/src/test/scala/frameless/ops/RollupTests.scala b/dataset/src/test/scala/frameless/ops/RollupTests.scala
index da73ef8d0..cdccba47d 100644
--- a/dataset/src/test/scala/frameless/ops/RollupTests.scala
+++ b/dataset/src/test/scala/frameless/ops/RollupTests.scala
@@ -8,14 +8,23 @@ import org.scalacheck.Prop._
class RollupTests extends TypedDatasetSuite {
test("rollup('a).agg(count())") {
- def prop[A: TypedEncoder : Ordering, Out: TypedEncoder : Numeric]
- (data: List[X1[A]])(implicit summable: CatalystSummable[A, Out]): Prop = {
+ def prop[A: TypedEncoder: Ordering, Out: TypedEncoder: Numeric](
+ data: List[X1[A]]
+ )(implicit
+ summable: CatalystSummable[A, Out]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
- val received = dataset.rollup(A).agg(count()).collect().run().toVector.sortBy(_._2)
- val expected = dataset.dataset.rollup("a").count().collect().toVector
- .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1))).sortBy(_._2)
+ val received =
+ dataset.rollup(A).agg(count()).collect().run().toVector.sortBy(_._2)
+ val expected = dataset.dataset
+ .rollup("a")
+ .count()
+ .collect()
+ .toVector
+ .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1)))
+ .sortBy(_._2)
received ?= expected
}
@@ -24,15 +33,29 @@ class RollupTests extends TypedDatasetSuite {
}
test("rollup('a, 'b).agg(count())") {
- def prop[A: TypedEncoder : Ordering, B: TypedEncoder, Out: TypedEncoder : Numeric]
- (data: List[X2[A, B]])(implicit summable: CatalystSummable[B, Out]): Prop = {
+ def prop[
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder,
+ Out: TypedEncoder: Numeric
+ ](data: List[X2[A, B]]
+ )(implicit
+ summable: CatalystSummable[B, Out]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
- val received = dataset.rollup(A, B).agg(count()).collect().run().toVector.sortBy(_._3)
- val expected = dataset.dataset.rollup("a", "b").count().collect().toVector
- .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[Long](2))).sortBy(_._3)
+ val received =
+ dataset.rollup(A, B).agg(count()).collect().run().toVector.sortBy(_._3)
+ val expected = dataset.dataset
+ .rollup("a", "b")
+ .count()
+ .collect()
+ .toVector
+ .map(row =>
+ (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[Long](2))
+ )
+ .sortBy(_._3)
received ?= expected
}
@@ -41,15 +64,27 @@ class RollupTests extends TypedDatasetSuite {
}
test("rollup('a).agg(sum('b)") {
- def prop[A: TypedEncoder : Ordering, B: TypedEncoder, Out: TypedEncoder : Numeric]
- (data: List[X2[A, B]])(implicit summable: CatalystSummable[B, Out]): Prop = {
+ def prop[
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder,
+ Out: TypedEncoder: Numeric
+ ](data: List[X2[A, B]]
+ )(implicit
+ summable: CatalystSummable[B, Out]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
- val received = dataset.rollup(A).agg(sum(B)).collect().run().toVector.sortBy(_._2)
- val expected = dataset.dataset.rollup("a").sum("b").collect().toVector
- .map(row => (Option(row.getAs[A](0)), row.getAs[Out](1))).sortBy(_._2)
+ val received =
+ dataset.rollup(A).agg(sum(B)).collect().run().toVector.sortBy(_._2)
+ val expected = dataset.dataset
+ .rollup("a")
+ .sum("b")
+ .collect()
+ .toVector
+ .map(row => (Option(row.getAs[A](0)), row.getAs[Out](1)))
+ .sortBy(_._2)
received ?= expected
}
@@ -58,15 +93,22 @@ class RollupTests extends TypedDatasetSuite {
}
test("rollup('a).mapGroups('a, sum('b))") {
- def prop[A: TypedEncoder : Ordering, B: TypedEncoder : Numeric]
- (data: List[X2[A, B]]): Prop = {
+ def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Numeric](
+ data: List[X2[A, B]]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
- val received = dataset.rollup(A)
- .deserialized.mapGroups { case (a, xs) => (a, xs.map(_.b).sum) }
- .collect().run().toVector.sortBy(_._1)
- val expected = data.groupBy(_.a).mapValues(_.map(_.b).sum).toVector.sortBy(_._1)
+ val received = dataset
+ .rollup(A)
+ .deserialized
+ .mapGroups { case (a, xs) => (a, xs.map(_.b).sum) }
+ .collect()
+ .run()
+ .toVector
+ .sortBy(_._1)
+ val expected =
+ data.groupBy(_.a).mapValues(_.map(_.b).sum).toVector.sortBy(_._1)
received ?= expected
}
@@ -76,16 +118,16 @@ class RollupTests extends TypedDatasetSuite {
test("rollup('a).agg(sum('b), sum('c)) to rollup('a).agg(sum('a), sum('b), sum('a), sum('b), sum('a))") {
def prop[
- A: TypedEncoder : Ordering,
- B: TypedEncoder,
- C: TypedEncoder,
- OutB: TypedEncoder : Numeric,
- OutC: TypedEncoder : Numeric
- ](data: List[X3[A, B, C]])(
- implicit
- summableB: CatalystSummable[B, OutB],
- summableC: CatalystSummable[C, OutC]
- ): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder,
+ C: TypedEncoder,
+ OutB: TypedEncoder: Numeric,
+ OutC: TypedEncoder: Numeric
+ ](data: List[X3[A, B, C]]
+ )(implicit
+ summableB: CatalystSummable[B, OutB],
+ summableC: CatalystSummable[C, OutC]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
@@ -94,37 +136,91 @@ class RollupTests extends TypedDatasetSuite {
val framelessSumBC = dataset
.rollup(A)
.agg(sum(B), sum(C))
- .collect().run().toVector.sortBy(_._1)
+ .collect()
+ .run()
+ .toVector
+ .sortBy(_._1)
- val sparkSumBC = dataset.dataset.rollup("a").sum("b", "c").collect().toVector
- .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2)))
+ val sparkSumBC = dataset.dataset
+ .rollup("a")
+ .sum("b", "c")
+ .collect()
+ .toVector
+ .map(row =>
+ (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2))
+ )
.sortBy(_._1)
val framelessSumBCB = dataset
.rollup(A)
.agg(sum(B), sum(C), sum(B))
- .collect().run().toVector.sortBy(_._1)
+ .collect()
+ .run()
+ .toVector
+ .sortBy(_._1)
- val sparkSumBCB = dataset.dataset.rollup("a").sum("b", "c", "b").collect().toVector
- .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2), row.getAs[OutB](3)))
+ val sparkSumBCB = dataset.dataset
+ .rollup("a")
+ .sum("b", "c", "b")
+ .collect()
+ .toVector
+ .map(row =>
+ (
+ Option(row.getAs[A](0)),
+ row.getAs[OutB](1),
+ row.getAs[OutC](2),
+ row.getAs[OutB](3)
+ )
+ )
.sortBy(_._1)
val framelessSumBCBC = dataset
.rollup(A)
.agg(sum(B), sum(C), sum(B), sum(C))
- .collect().run().toVector.sortBy(_._1)
+ .collect()
+ .run()
+ .toVector
+ .sortBy(_._1)
- val sparkSumBCBC = dataset.dataset.rollup("a").sum("b", "c", "b", "c").collect().toVector
- .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2), row.getAs[OutB](3), row.getAs[OutC](4)))
+ val sparkSumBCBC = dataset.dataset
+ .rollup("a")
+ .sum("b", "c", "b", "c")
+ .collect()
+ .toVector
+ .map(row =>
+ (
+ Option(row.getAs[A](0)),
+ row.getAs[OutB](1),
+ row.getAs[OutC](2),
+ row.getAs[OutB](3),
+ row.getAs[OutC](4)
+ )
+ )
.sortBy(_._1)
val framelessSumBCBCB = dataset
.rollup(A)
.agg(sum(B), sum(C), sum(B), sum(C), sum(B))
- .collect().run().toVector.sortBy(_._1)
+ .collect()
+ .run()
+ .toVector
+ .sortBy(_._1)
- val sparkSumBCBCB = dataset.dataset.rollup("a").sum("b", "c", "b", "c", "b").collect().toVector
- .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2), row.getAs[OutB](3), row.getAs[OutC](4), row.getAs[OutB](5)))
+ val sparkSumBCBCB = dataset.dataset
+ .rollup("a")
+ .sum("b", "c", "b", "c", "b")
+ .collect()
+ .toVector
+ .map(row =>
+ (
+ Option(row.getAs[A](0)),
+ row.getAs[OutB](1),
+ row.getAs[OutC](2),
+ row.getAs[OutB](3),
+ row.getAs[OutC](4),
+ row.getAs[OutB](5)
+ )
+ )
.sortBy(_._1)
(framelessSumBC ?= sparkSumBC)
@@ -138,17 +234,17 @@ class RollupTests extends TypedDatasetSuite {
test("rollup('a, 'b).agg(sum('c), sum('d))") {
def prop[
- A: TypedEncoder : Ordering,
- B: TypedEncoder : Ordering,
- C: TypedEncoder,
- D: TypedEncoder,
- OutC: TypedEncoder : Numeric,
- OutD: TypedEncoder : Numeric
- ](data: List[X4[A, B, C, D]])(
- implicit
- summableC: CatalystSummable[C, OutC],
- summableD: CatalystSummable[D, OutD]
- ): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder: Ordering,
+ C: TypedEncoder,
+ D: TypedEncoder,
+ OutC: TypedEncoder: Numeric,
+ OutD: TypedEncoder: Numeric
+ ](data: List[X4[A, B, C, D]]
+ )(implicit
+ summableC: CatalystSummable[C, OutC],
+ summableD: CatalystSummable[D, OutD]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
@@ -158,11 +254,24 @@ class RollupTests extends TypedDatasetSuite {
val framelessSumByAB = dataset
.rollup(A, B)
.agg(sum(C), sum(D))
- .collect().run().toVector.sortBy(_._2)
+ .collect()
+ .run()
+ .toVector
+ .sortBy(_._2)
val sparkSumByAB = dataset.dataset
- .rollup("a", "b").sum("c", "d").collect().toVector
- .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutD](3)))
+ .rollup("a", "b")
+ .sum("c", "d")
+ .collect()
+ .toVector
+ .map(row =>
+ (
+ Option(row.getAs[A](0)),
+ Option(row.getAs[B](1)),
+ row.getAs[OutC](2),
+ row.getAs[OutD](3)
+ )
+ )
.sortBy(_._2)
framelessSumByAB ?= sparkSumByAB
@@ -173,76 +282,135 @@ class RollupTests extends TypedDatasetSuite {
test("rollup('a, 'b).agg(sum('c)) to rollup('a, 'b).agg(sum('c),sum('c),sum('c),sum('c),sum('c))") {
def prop[
- A: TypedEncoder : Ordering,
- B: TypedEncoder : Ordering,
- C: TypedEncoder,
- OutC: TypedEncoder: Numeric
- ](data: List[X3[A, B, C]])(implicit summableC: CatalystSummable[C, OutC]): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder: Ordering,
+ C: TypedEncoder,
+ OutC: TypedEncoder: Numeric
+ ](data: List[X3[A, B, C]]
+ )(implicit
+ summableC: CatalystSummable[C, OutC]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
val C = dataset.col[C]('c)
- val framelessSumC = dataset
- .rollup(A, B)
- .agg(sum(C))
- .collect().run().toVector
- .sortBy(_._2)
+ val framelessSumC =
+ dataset.rollup(A, B).agg(sum(C)).collect().run().toVector.sortBy(_._2)
val sparkSumC = dataset.dataset
- .rollup("a", "b").sum("c").collect().toVector
- .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2)))
+ .rollup("a", "b")
+ .sum("c")
+ .collect()
+ .toVector
+ .map(row =>
+ (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2))
+ )
.sortBy(_._2)
val framelessSumCC = dataset
.rollup(A, B)
.agg(sum(C), sum(C))
- .collect().run().toVector
+ .collect()
+ .run()
+ .toVector
.sortBy(_._2)
val sparkSumCC = dataset.dataset
- .rollup("a", "b").sum("c", "c").collect().toVector
- .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3)))
+ .rollup("a", "b")
+ .sum("c", "c")
+ .collect()
+ .toVector
+ .map(row =>
+ (
+ Option(row.getAs[A](0)),
+ Option(row.getAs[B](1)),
+ row.getAs[OutC](2),
+ row.getAs[OutC](3)
+ )
+ )
.sortBy(_._2)
val framelessSumCCC = dataset
.rollup(A, B)
.agg(sum(C), sum(C), sum(C))
- .collect().run().toVector
+ .collect()
+ .run()
+ .toVector
.sortBy(_._2)
val sparkSumCCC = dataset.dataset
- .rollup("a", "b").sum("c", "c", "c").collect().toVector
- .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3), row.getAs[OutC](4)))
+ .rollup("a", "b")
+ .sum("c", "c", "c")
+ .collect()
+ .toVector
+ .map(row =>
+ (
+ Option(row.getAs[A](0)),
+ Option(row.getAs[B](1)),
+ row.getAs[OutC](2),
+ row.getAs[OutC](3),
+ row.getAs[OutC](4)
+ )
+ )
.sortBy(_._2)
val framelessSumCCCC = dataset
.rollup(A, B)
.agg(sum(C), sum(C), sum(C), sum(C))
- .collect().run().toVector
+ .collect()
+ .run()
+ .toVector
.sortBy(_._2)
val sparkSumCCCC = dataset.dataset
- .rollup("a", "b").sum("c", "c", "c", "c").collect().toVector
- .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3), row.getAs[OutC](4), row.getAs[OutC](5)))
+ .rollup("a", "b")
+ .sum("c", "c", "c", "c")
+ .collect()
+ .toVector
+ .map(row =>
+ (
+ Option(row.getAs[A](0)),
+ Option(row.getAs[B](1)),
+ row.getAs[OutC](2),
+ row.getAs[OutC](3),
+ row.getAs[OutC](4),
+ row.getAs[OutC](5)
+ )
+ )
.sortBy(_._2)
val framelessSumCCCCC = dataset
.rollup(A, B)
.agg(sum(C), sum(C), sum(C), sum(C), sum(C))
- .collect().run().toVector
+ .collect()
+ .run()
+ .toVector
.sortBy(_._2)
val sparkSumCCCCC = dataset.dataset
- .rollup("a", "b").sum("c", "c", "c", "c", "c").collect().toVector
- .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3), row.getAs[OutC](4), row.getAs[OutC](5), row.getAs[OutC](6)))
+ .rollup("a", "b")
+ .sum("c", "c", "c", "c", "c")
+ .collect()
+ .toVector
+ .map(row =>
+ (
+ Option(row.getAs[A](0)),
+ Option(row.getAs[B](1)),
+ row.getAs[OutC](2),
+ row.getAs[OutC](3),
+ row.getAs[OutC](4),
+ row.getAs[OutC](5),
+ row.getAs[OutC](6)
+ )
+ )
.sortBy(_._2)
(framelessSumC ?= sparkSumC) &&
- (framelessSumCC ?= sparkSumCC) &&
- (framelessSumCCC ?= sparkSumCCC) &&
- (framelessSumCCCC ?= sparkSumCCCC) &&
- (framelessSumCCCCC ?= sparkSumCCCCC)
+ (framelessSumCC ?= sparkSumCC) &&
+ (framelessSumCCC ?= sparkSumCCC) &&
+ (framelessSumCCCC ?= sparkSumCCCC) &&
+ (framelessSumCCCCC ?= sparkSumCCCCC)
}
check(forAll(prop[String, Long, Double, Double] _))
@@ -250,22 +418,30 @@ class RollupTests extends TypedDatasetSuite {
test("rollup('a, 'b).mapGroups('a, 'b, sum('c))") {
def prop[
- A: TypedEncoder : Ordering,
- B: TypedEncoder : Ordering,
- C: TypedEncoder : Numeric
- ](data: List[X3[A, B, C]]): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder: Ordering,
+ C: TypedEncoder: Numeric
+ ](data: List[X3[A, B, C]]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val B = dataset.col[B]('b)
val framelessSumByAB = dataset
.rollup(A, B)
- .deserialized.mapGroups { case ((a, b), xs) => (a, b, xs.map(_.c).sum) }
- .collect().run().toVector.sortBy(x => (x._1, x._2))
-
- val sumByAB = data.groupBy(x => (x.a, x.b))
+ .deserialized
+ .mapGroups { case ((a, b), xs) => (a, b, xs.map(_.c).sum) }
+ .collect()
+ .run()
+ .toVector
+ .sortBy(x => (x._1, x._2))
+
+ val sumByAB = data
+ .groupBy(x => (x.a, x.b))
.mapValues { xs => xs.map(_.c).sum }
- .toVector.map { case ((a, b), c) => (a, b, c) }.sortBy(x => (x._1, x._2))
+ .toVector
+ .map { case ((a, b), c) => (a, b, c) }
+ .sortBy(x => (x._1, x._2))
framelessSumByAB ?= sumByAB
}
@@ -274,17 +450,19 @@ class RollupTests extends TypedDatasetSuite {
}
test("rollup('a).mapGroups(('a, toVector(('a, 'b))") {
- def prop[
- A: TypedEncoder: Ordering,
- B: TypedEncoder: Ordering
- ](data: Vector[X2[A, B]]): Prop = {
+ def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Ordering](
+ data: Vector[X2[A, B]]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val datasetGrouped = dataset
.rollup(A)
- .deserialized.mapGroups((a, xs) => (a, xs.toVector.sorted))
- .collect().run().toMap
+ .deserialized
+ .mapGroups((a, xs) => (a, xs.toVector.sorted))
+ .collect()
+ .run()
+ .toMap
val dataGrouped = data.groupBy(_.a).map { case (k, v) => k -> v.sorted }
@@ -297,21 +475,23 @@ class RollupTests extends TypedDatasetSuite {
}
test("rollup('a).flatMapGroups(('a, toVector(('a, 'b))") {
- def prop[
- A: TypedEncoder : Ordering,
- B: TypedEncoder : Ordering
- ](data: Vector[X2[A, B]]): Prop = {
+ def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Ordering](
+ data: Vector[X2[A, B]]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
val datasetGrouped = dataset
.rollup(A)
- .deserialized.flatMapGroups((a, xs) => xs.map(x => (a, x)))
- .collect().run()
+ .deserialized
+ .flatMapGroups((a, xs) => xs.map(x => (a, x)))
+ .collect()
+ .run()
.sorted
val dataGrouped = data
- .groupBy(_.a).toSeq
+ .groupBy(_.a)
+ .toSeq
.flatMap { case (a, xs) => xs.map(x => (a, x)) }
.sorted
@@ -325,22 +505,26 @@ class RollupTests extends TypedDatasetSuite {
test("rollup('a, 'b).flatMapGroups((('a,'b) toVector((('a,'b), 'c))") {
def prop[
- A: TypedEncoder : Ordering,
- B: TypedEncoder : Ordering,
- C: TypedEncoder : Ordering
- ](data: Vector[X3[A, B, C]]): Prop = {
+ A: TypedEncoder: Ordering,
+ B: TypedEncoder: Ordering,
+ C: TypedEncoder: Ordering
+ ](data: Vector[X3[A, B, C]]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val cA = dataset.col[A]('a)
val cB = dataset.col[B]('b)
val datasetGrouped = dataset
.rollup(cA, cB)
- .deserialized.flatMapGroups((a, xs) => xs.map(x => (a, x)))
- .collect().run()
+ .deserialized
+ .flatMapGroups((a, xs) => xs.map(x => (a, x)))
+ .collect()
+ .run()
.sorted
val dataGrouped = data
- .groupBy(t => (t.a, t.b)).toSeq
+ .groupBy(t => (t.a, t.b))
+ .toSeq
.flatMap { case (a, xs) => xs.map(x => (a, x)) }
.sorted
@@ -353,18 +537,32 @@ class RollupTests extends TypedDatasetSuite {
}
test("rollupMany('a).agg(sum('b))") {
- def prop[A: TypedEncoder : Ordering, Out: TypedEncoder : Numeric]
- (data: List[X1[A]])(implicit summable: CatalystSummable[A, Out]): Prop = {
+ def prop[A: TypedEncoder: Ordering, Out: TypedEncoder: Numeric](
+ data: List[X1[A]]
+ )(implicit
+ summable: CatalystSummable[A, Out]
+ ): Prop = {
val dataset = TypedDataset.create(data)
val A = dataset.col[A]('a)
- val received = dataset.rollupMany(A).agg(count[X1[A]]()).collect().run().toVector.sortBy(_._2)
- val expected = dataset.dataset.rollup("a").count().collect().toVector
- .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1))).sortBy(_._2)
+ val received = dataset
+ .rollupMany(A)
+ .agg(count[X1[A]]())
+ .collect()
+ .run()
+ .toVector
+ .sortBy(_._2)
+ val expected = dataset.dataset
+ .rollup("a")
+ .count()
+ .collect()
+ .toVector
+ .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1)))
+ .sortBy(_._2)
received ?= expected
}
check(forAll(prop[Int, Long] _))
}
-}
\ No newline at end of file
+}
diff --git a/dataset/src/test/scala/frameless/ops/SmartProjectTest.scala b/dataset/src/test/scala/frameless/ops/SmartProjectTest.scala
index 233a42aec..8b507c1a5 100644
--- a/dataset/src/test/scala/frameless/ops/SmartProjectTest.scala
+++ b/dataset/src/test/scala/frameless/ops/SmartProjectTest.scala
@@ -5,15 +5,16 @@ import org.scalacheck.Prop
import org.scalacheck.Prop._
import shapeless.test.illTyped
-
case class Foo(i: Int, j: Int, x: String)
case class Bar(i: Int, x: String)
case class InvalidFooProjectionType(i: Int, x: Boolean)
case class InvalidFooProjectionName(i: Int, xerr: String)
class SmartProjectTest extends TypedDatasetSuite {
+
// Lazy needed to prevent initialization anterior to the `beforeAll` hook
- lazy val dataset = TypedDataset.create(Foo(1, 2, "hi") :: Foo(2, 3, "there") :: Nil)
+ lazy val dataset =
+ TypedDataset.create(Foo(1, 2, "hi") :: Foo(2, 3, "there") :: Nil)
test("project Foo to Bar") {
assert(dataset.project[Bar].count().run() === 2)
@@ -25,28 +26,50 @@ class SmartProjectTest extends TypedDatasetSuite {
}
test("X4 to X1,X2,X3,X4 projections") {
- def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder, D: TypedEncoder](data: Vector[X4[A, B, C, D]]): Prop = {
+ def prop[
+ A: TypedEncoder,
+ B: TypedEncoder,
+ C: TypedEncoder,
+ D: TypedEncoder
+ ](data: Vector[X4[A, B, C, D]]
+ ): Prop = {
val dataset = TypedDataset.create(data)
dataset.project[X4[A, B, C, D]].collect().run().toVector ?= data
- dataset.project[X3[A, B, C]].collect().run().toVector ?= data.map(x => X3(x.a, x.b, x.c))
- dataset.project[X2[A, B]].collect().run().toVector ?= data.map(x => X2(x.a, x.b))
+ dataset.project[X3[A, B, C]].collect().run().toVector ?= data.map(x =>
+ X3(x.a, x.b, x.c)
+ )
+ dataset.project[X2[A, B]].collect().run().toVector ?= data.map(x =>
+ X2(x.a, x.b)
+ )
dataset.project[X1[A]].collect().run().toVector ?= data.map(x => X1(x.a))
}
check(forAll(prop[Int, String, X1[String], Boolean] _))
check(forAll(prop[Short, Long, String, Boolean] _))
check(forAll(prop[Short, (Boolean, Boolean), String, (Int, Int)] _))
- check(forAll(prop[X2[String, Boolean], (Boolean, Boolean), String, Boolean] _))
- check(forAll(prop[X2[String, Boolean], X3[Boolean, Boolean, Long], String, String] _))
+ check(
+ forAll(prop[X2[String, Boolean], (Boolean, Boolean), String, Boolean] _)
+ )
+ check(
+ forAll(
+ prop[X2[String, Boolean], X3[Boolean, Boolean, Long], String, String] _
+ )
+ )
}
test("X3U to X1,X2,X3 projections") {
- def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder](data: Vector[X3U[A, B, C]]): Prop = {
+ def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder](
+ data: Vector[X3U[A, B, C]]
+ ): Prop = {
val dataset = TypedDataset.create(data)
- dataset.project[X3[A, B, C]].collect().run().toVector ?= data.map(x => X3(x.a, x.b, x.c))
- dataset.project[X2[A, B]].collect().run().toVector ?= data.map(x => X2(x.a, x.b))
+ dataset.project[X3[A, B, C]].collect().run().toVector ?= data.map(x =>
+ X3(x.a, x.b, x.c)
+ )
+ dataset.project[X2[A, B]].collect().run().toVector ?= data.map(x =>
+ X2(x.a, x.b)
+ )
dataset.project[X1[A]].collect().run().toVector ?= data.map(x => X1(x.a))
}
@@ -54,6 +77,8 @@ class SmartProjectTest extends TypedDatasetSuite {
check(forAll(prop[Short, Long, String] _))
check(forAll(prop[Short, (Boolean, Boolean), String] _))
check(forAll(prop[X2[String, Boolean], (Boolean, Boolean), String] _))
- check(forAll(prop[X2[String, Boolean], X3[Boolean, Boolean, Long], String] _))
+ check(
+ forAll(prop[X2[String, Boolean], X3[Boolean, Boolean, Long], String] _)
+ )
}
}
diff --git a/dataset/src/test/scala/frameless/ops/deserialized/FilterTests.scala b/dataset/src/test/scala/frameless/ops/deserialized/FilterTests.scala
index b53000f09..28776ddb2 100644
--- a/dataset/src/test/scala/frameless/ops/deserialized/FilterTests.scala
+++ b/dataset/src/test/scala/frameless/ops/deserialized/FilterTests.scala
@@ -7,11 +7,17 @@ import org.scalacheck.Prop._
class FilterTests extends TypedDatasetSuite {
test("filter") {
- def prop[A: TypedEncoder](filterFunction: A => Boolean, data: Vector[A]): Prop =
- TypedDataset.create(data).
- deserialized.
- filter(filterFunction).
- collect().run().toVector =? data.filter(filterFunction)
+ def prop[A: TypedEncoder](
+ filterFunction: A => Boolean,
+ data: Vector[A]
+ ): Prop =
+ TypedDataset
+ .create(data)
+ .deserialized
+ .filter(filterFunction)
+ .collect()
+ .run()
+ .toVector =? data.filter(filterFunction)
check(forAll(prop[Int] _))
check(forAll(prop[String] _))
diff --git a/dataset/src/test/scala/frameless/ops/deserialized/FlatMapTests.scala b/dataset/src/test/scala/frameless/ops/deserialized/FlatMapTests.scala
index 7dcd0e4e3..409d631e1 100644
--- a/dataset/src/test/scala/frameless/ops/deserialized/FlatMapTests.scala
+++ b/dataset/src/test/scala/frameless/ops/deserialized/FlatMapTests.scala
@@ -7,11 +7,17 @@ import org.scalacheck.Prop._
class FlatMapTests extends TypedDatasetSuite {
test("flatMap") {
- def prop[A: TypedEncoder, B: TypedEncoder](flatMapFunction: A => Vector[B], data: Vector[A]): Prop =
- TypedDataset.create(data).
- deserialized.
- flatMap(flatMapFunction).
- collect().run().toVector =? data.flatMap(flatMapFunction)
+ def prop[A: TypedEncoder, B: TypedEncoder](
+ flatMapFunction: A => Vector[B],
+ data: Vector[A]
+ ): Prop =
+ TypedDataset
+ .create(data)
+ .deserialized
+ .flatMap(flatMapFunction)
+ .collect()
+ .run()
+ .toVector =? data.flatMap(flatMapFunction)
check(forAll(prop[Int, Int] _))
check(forAll(prop[Int, String] _))
diff --git a/dataset/src/test/scala/frameless/ops/deserialized/MapPartitionsTests.scala b/dataset/src/test/scala/frameless/ops/deserialized/MapPartitionsTests.scala
index 06ba04943..def7c9614 100644
--- a/dataset/src/test/scala/frameless/ops/deserialized/MapPartitionsTests.scala
+++ b/dataset/src/test/scala/frameless/ops/deserialized/MapPartitionsTests.scala
@@ -7,12 +7,18 @@ import org.scalacheck.Prop._
class MapPartitionsTests extends TypedDatasetSuite {
test("mapPartitions") {
- def prop[A: TypedEncoder, B: TypedEncoder](mapFunction: A => B, data: Vector[A]): Prop = {
+ def prop[A: TypedEncoder, B: TypedEncoder](
+ mapFunction: A => B,
+ data: Vector[A]
+ ): Prop = {
val lifted: Iterator[A] => Iterator[B] = _.map(mapFunction)
- TypedDataset.create(data).
- deserialized.
- mapPartitions(lifted).
- collect().run().toVector =? data.map(mapFunction)
+ TypedDataset
+ .create(data)
+ .deserialized
+ .mapPartitions(lifted)
+ .collect()
+ .run()
+ .toVector =? data.map(mapFunction)
}
check(forAll(prop[Int, Int] _))
diff --git a/dataset/src/test/scala/frameless/ops/deserialized/MapTests.scala b/dataset/src/test/scala/frameless/ops/deserialized/MapTests.scala
index f7cc0fad0..3cad10439 100644
--- a/dataset/src/test/scala/frameless/ops/deserialized/MapTests.scala
+++ b/dataset/src/test/scala/frameless/ops/deserialized/MapTests.scala
@@ -7,11 +7,17 @@ import org.scalacheck.Prop._
class MapTests extends TypedDatasetSuite {
test("map") {
- def prop[A: TypedEncoder, B: TypedEncoder](mapFunction: A => B, data: Vector[A]): Prop =
- TypedDataset.create(data).
- deserialized.
- map(mapFunction).
- collect().run().toVector =? data.map(mapFunction)
+ def prop[A: TypedEncoder, B: TypedEncoder](
+ mapFunction: A => B,
+ data: Vector[A]
+ ): Prop =
+ TypedDataset
+ .create(data)
+ .deserialized
+ .map(mapFunction)
+ .collect()
+ .run()
+ .toVector =? data.map(mapFunction)
check(forAll(prop[Int, Int] _))
check(forAll(prop[Int, String] _))
diff --git a/dataset/src/test/scala/frameless/ops/deserialized/ReduceTests.scala b/dataset/src/test/scala/frameless/ops/deserialized/ReduceTests.scala
index 01a074950..934475bb5 100644
--- a/dataset/src/test/scala/frameless/ops/deserialized/ReduceTests.scala
+++ b/dataset/src/test/scala/frameless/ops/deserialized/ReduceTests.scala
@@ -6,10 +6,16 @@ import org.scalacheck.Prop
import org.scalacheck.Prop._
class ReduceTests extends TypedDatasetSuite {
- def prop[A: TypedEncoder](reduceFunction: (A, A) => A)(data: Vector[A]): Prop =
- TypedDataset.create(data).
- deserialized.
- reduceOption(reduceFunction).run() =? data.reduceOption(reduceFunction)
+
+ def prop[A: TypedEncoder](
+ reduceFunction: (A, A) => A
+ )(data: Vector[A]
+ ): Prop =
+ TypedDataset
+ .create(data)
+ .deserialized
+ .reduceOption(reduceFunction)
+ .run() =? data.reduceOption(reduceFunction)
test("reduce Int") {
check(forAll(prop[Int](_ + _) _))
diff --git a/dataset/src/test/scala/frameless/package.scala b/dataset/src/test/scala/frameless/package.scala
index 82ff375c9..54972786e 100644
--- a/dataset/src/test/scala/frameless/package.scala
+++ b/dataset/src/test/scala/frameless/package.scala
@@ -1,9 +1,10 @@
import java.time.format.DateTimeFormatter
-import java.time.{LocalDateTime => JavaLocalDateTime}
+import java.time.{ LocalDateTime => JavaLocalDateTime }
-import org.scalacheck.{Arbitrary, Gen}
+import org.scalacheck.{ Arbitrary, Gen }
package object frameless {
+
/** Fixed decimal point to avoid precision problems specific to Spark */
implicit val arbBigDecimal: Arbitrary[BigDecimal] = Arbitrary {
for {
@@ -30,7 +31,10 @@ package object frameless {
}
// see issue with scalacheck non serializable Vector: https://github.com/rickynils/scalacheck/issues/315
- implicit def arbVector[A](implicit A: Arbitrary[A]): Arbitrary[Vector[A]] =
+ implicit def arbVector[A](
+ implicit
+ A: Arbitrary[A]
+ ): Arbitrary[Vector[A]] =
Arbitrary(Gen.listOf(A.arbitrary).map(_.toVector))
def vectorGen[A: Arbitrary]: Gen[Vector[A]] = arbVector[A].arbitrary
@@ -42,7 +46,8 @@ package object frameless {
} yield new UdtEncodedClass(int, doubles.toArray)
}
- val dateTimeFormatter: DateTimeFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm")
+ val dateTimeFormatter: DateTimeFormatter =
+ DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm")
implicit val localDateArb: Arbitrary[JavaLocalDateTime] = Arbitrary {
for {
@@ -72,11 +77,10 @@ package object frameless {
def anyCauseHas(t: Throwable, f: Throwable => Boolean): Boolean =
if (f(t))
true
+ else if (t.getCause ne null)
+ anyCauseHas(t.getCause, f)
else
- if (t.getCause ne null)
- anyCauseHas(t.getCause, f)
- else
- false
+ false
/**
* Runs up to maxRuns and outputs the number of failures (times thrown)
@@ -85,11 +89,11 @@ package object frameless {
* @tparam T
* @return the last passing thunk, or null
*/
- def runLoads[T](maxRuns: Int = 1000)(thunk: => T): T ={
+ def runLoads[T](maxRuns: Int = 1000)(thunk: => T): T = {
var i = 0
var r = null.asInstanceOf[T]
var passed = 0
- while(i < maxRuns){
+ while (i < maxRuns) {
i += 1
try {
r = thunk
@@ -98,29 +102,36 @@ package object frameless {
println(s"run $i successful")
}
} catch {
- case t: Throwable => System.err.println(s"failed unexpectedly on run $i - ${t.getMessage}")
+ case t: Throwable =>
+ System.err.println(s"failed unexpectedly on run $i - ${t.getMessage}")
}
}
if (passed != maxRuns) {
- System.err.println(s"had ${maxRuns - passed} failures out of $maxRuns runs")
+ System.err.println(
+ s"had ${maxRuns - passed} failures out of $maxRuns runs"
+ )
}
r
}
- /**
+ /**
* Runs a given thunk up to maxRuns times, restarting the thunk if tolerantOf the thrown Throwable is true
* @param tolerantOf
* @param maxRuns default of 20
* @param thunk
* @return either a successful run result or the last error will be thrown
*/
- def tolerantRun[T](tolerantOf: Throwable => Boolean, maxRuns: Int = 20)(thunk: => T): T ={
+ def tolerantRun[T](
+ tolerantOf: Throwable => Boolean,
+ maxRuns: Int = 20
+ )(thunk: => T
+ ): T = {
var passed = false
var i = 0
var res: T = null.asInstanceOf[T]
var thrown: Throwable = null
- while((i < maxRuns) && !passed) {
+ while ((i < maxRuns) && !passed) {
try {
i += 1
res = thunk
diff --git a/dataset/src/test/scala/frameless/sql/package.scala b/dataset/src/test/scala/frameless/sql/package.scala
index fcb45b03d..936d78042 100644
--- a/dataset/src/test/scala/frameless/sql/package.scala
+++ b/dataset/src/test/scala/frameless/sql/package.scala
@@ -1,16 +1,18 @@
package frameless
import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.expressions.{And, Or}
+import org.apache.spark.sql.catalyst.expressions.{ And, Or }
package object sql {
+
implicit class ExpressionOps(val self: Expression) extends AnyVal {
+
def toList: List[Expression] = {
def rec(expr: Expression, acc: List[Expression]): List[Expression] = {
expr match {
case And(left, right) => rec(left, rec(right, acc))
- case Or(left, right) => rec(left, rec(right, acc))
- case e => e +: acc
+ case Or(left, right) => rec(left, rec(right, acc))
+ case e => e +: acc
}
}
diff --git a/dataset/src/test/scala/frameless/sql/rules/SQLRulesSuite.scala b/dataset/src/test/scala/frameless/sql/rules/SQLRulesSuite.scala
index 8555d1809..c4cce0b14 100644
--- a/dataset/src/test/scala/frameless/sql/rules/SQLRulesSuite.scala
+++ b/dataset/src/test/scala/frameless/sql/rules/SQLRulesSuite.scala
@@ -11,30 +11,36 @@ import org.scalatest.Assertion
import org.scalatest.matchers.should.Matchers
trait SQLRulesSuite extends TypedDatasetSuite with Matchers { self =>
+
protected lazy val path: String = {
val tmpDir = System.getProperty("java.io.tmpdir")
s"$tmpDir/${self.getClass.getName}"
}
- def withDataset[A: TypedEncoder: CatalystOrdered](payload: A)(f: TypedDataset[A] => Assertion): Assertion = {
+ def withDataset[A: TypedEncoder: CatalystOrdered](
+ payload: A
+ )(f: TypedDataset[A] => Assertion
+ ): Assertion = {
TypedDataset.create(Seq(payload)).write.mode("overwrite").parquet(path)
f(TypedDataset.createUnsafe[A](session.read.parquet(path)))
}
def predicatePushDownTest[A: TypedEncoder: CatalystOrdered](
- expected: X1[A],
- expectedPushDownFilters: List[Filter],
- planShouldNotContain: PartialFunction[Expression, Expression],
- op: TypedColumn[X1[A], A] => TypedColumn[X1[A], Boolean]
- ): Assertion = {
+ expected: X1[A],
+ expectedPushDownFilters: List[Filter],
+ planShouldNotContain: PartialFunction[Expression, Expression],
+ op: TypedColumn[X1[A], A] => TypedColumn[X1[A], Boolean]
+ ): Assertion = {
withDataset(expected) { dataset =>
val ds = dataset.filter(op(dataset('a)))
val actualPushDownFilters = pushDownFilters(ds)
- val optimizedPlan = ds.queryExecution.optimizedPlan.collect { case logical.Filter(condition, _) => condition }.flatMap(_.toList)
+ val optimizedPlan = ds.queryExecution.optimizedPlan.collect {
+ case logical.Filter(condition, _) => condition
+ }.flatMap(_.toList)
// check the optimized plan
- optimizedPlan.collectFirst(planShouldNotContain) should be (empty)
+ optimizedPlan.collectFirst(planShouldNotContain) should be(empty)
// compare filters
actualPushDownFilters shouldBe expectedPushDownFilters
@@ -53,18 +59,22 @@ trait SQLRulesSuite extends TypedDatasetSuite with Matchers { self =>
if (sparkPlan.children.isEmpty) // assume it's AQE
sparkPlan match {
case aq: AdaptiveSparkPlanExec => aq.initialPlan
- case _ => sparkPlan
+ case _ => sparkPlan
}
else
sparkPlan
initialPlan.collect {
case fs: FileSourceScanExec =>
- import scala.reflect.runtime.{universe => ru}
+ import scala.reflect.runtime.{ universe => ru }
val runtimeMirror = ru.runtimeMirror(getClass.getClassLoader)
val instanceMirror = runtimeMirror.reflect(fs)
- val getter = ru.typeOf[FileSourceScanExec].member(ru.TermName("pushedDownFilters")).asTerm.getter
+ val getter = ru
+ .typeOf[FileSourceScanExec]
+ .member(ru.TermName("pushedDownFilters"))
+ .asTerm
+ .getter
val m = instanceMirror.reflectMethod(getter.asMethod)
val res = m.apply(fs).asInstanceOf[Seq[Filter]]
diff --git a/dataset/src/test/scala/frameless/syntax/FramelessSyntaxTests.scala b/dataset/src/test/scala/frameless/syntax/FramelessSyntaxTests.scala
index 5108ed581..51839e0a3 100644
--- a/dataset/src/test/scala/frameless/syntax/FramelessSyntaxTests.scala
+++ b/dataset/src/test/scala/frameless/syntax/FramelessSyntaxTests.scala
@@ -9,26 +9,37 @@ class FramelessSyntaxTests extends TypedDatasetSuite {
// Hide the implicit SparkDelay[Job] on TypedDatasetSuite to avoid ambiguous implicits
override val sparkDelay = null
- def prop[A, B](data: Vector[X2[A, B]])(
- implicit ev: TypedEncoder[X2[A, B]]
- ): Prop = {
+ def prop[A, B](
+ data: Vector[X2[A, B]]
+ )(implicit
+ ev: TypedEncoder[X2[A, B]]
+ ): Prop = {
val dataset = TypedDataset.create(data).dataset
val dataframe = dataset.toDF()
val typedDataset = dataset.typed
val typedDatasetFromDataFrame = dataframe.unsafeTyped[X2[A, B]]
- typedDataset.collect().run().toVector ?= typedDatasetFromDataFrame.collect().run().toVector
+ typedDataset.collect().run().toVector ?= typedDatasetFromDataFrame
+ .collect()
+ .run()
+ .toVector
}
test("dataset typed - toTyped") {
- def prop[A, B](data: Vector[X2[A, B]])(
- implicit ev: TypedEncoder[X2[A, B]]
- ): Prop = {
- val dataset = session.createDataset(data)(TypedExpressionEncoder(ev)).typed
+ def prop[A, B](
+ data: Vector[X2[A, B]]
+ )(implicit
+ ev: TypedEncoder[X2[A, B]]
+ ): Prop = {
+ val dataset =
+ session.createDataset(data)(TypedExpressionEncoder(ev)).typed
val dataframe = dataset.toDF()
- dataset.collect().run().toVector ?= dataframe.unsafeTyped[X2[A, B]].collect().run().toVector
+ dataset
+ .collect()
+ .run()
+ .toVector ?= dataframe.unsafeTyped[X2[A, B]].collect().run().toVector
}
check(forAll(prop[Int, String] _))
@@ -38,8 +49,13 @@ class FramelessSyntaxTests extends TypedDatasetSuite {
test("frameless typed column and aggregate") {
def prop[A: TypedEncoder](a: A, b: A): Prop = {
val d = TypedDataset.create((a, b) :: Nil)
- (d.select(d('_1).untyped.typedColumn).collect().run ?= d.select(d('_1)).collect().run).&&(
- d.agg(first(d('_1))).collect().run() ?= d.agg(first(d('_1)).untyped.typedAggregate).collect().run()
+ (d.select(d('_1).untyped.typedColumn)
+ .collect()
+ .run ?= d.select(d('_1)).collect().run).&&(
+ d.agg(first(d('_1))).collect().run() ?= d
+ .agg(first(d('_1)).untyped.typedAggregate)
+ .collect()
+ .run()
)
}
diff --git a/dataset/src/test/scala/org/apache/hadoop/fs/local/StreamingFS.scala b/dataset/src/test/scala/org/apache/hadoop/fs/local/StreamingFS.scala
index a28ad0820..c6e4b8c29 100644
--- a/dataset/src/test/scala/org/apache/hadoop/fs/local/StreamingFS.scala
+++ b/dataset/src/test/scala/org/apache/hadoop/fs/local/StreamingFS.scala
@@ -3,5 +3,11 @@ package org.apache.hadoop.fs.local
import com.globalmentor.apache.hadoop.fs.BareLocalFileSystem
import org.apache.hadoop.fs.DelegateToFileSystem
-class StreamingFS(uri: java.net.URI, conf: org.apache.hadoop.conf.Configuration) extends
- DelegateToFileSystem(uri, new BareLocalFileSystem(), conf, "file", false) {}
+class StreamingFS(uri: java.net.URI, conf: org.apache.hadoop.conf.Configuration)
+ extends DelegateToFileSystem(
+ uri,
+ new BareLocalFileSystem(),
+ conf,
+ "file",
+ false
+ ) {}
diff --git a/dataset/src/test/spark-3.2/frameless/sql/rules/FramelessLitPushDownTests.scala b/dataset/src/test/spark-3.2/frameless/sql/rules/FramelessLitPushDownTests.scala
index c44ac4d08..6224dc78f 100644
--- a/dataset/src/test/spark-3.2/frameless/sql/rules/FramelessLitPushDownTests.scala
+++ b/dataset/src/test/spark-3.2/frameless/sql/rules/FramelessLitPushDownTests.scala
@@ -3,10 +3,17 @@ package frameless.sql.rules
import frameless._
import frameless.sql._
import frameless.functions.Lit
-import org.apache.spark.sql.catalyst.util.DateTimeUtils.{currentTimestamp, microsToInstant}
-import org.apache.spark.sql.sources.{Filter, IsNotNull}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{
+ currentTimestamp,
+ microsToInstant
+}
+import org.apache.spark.sql.sources.{ Filter, IsNotNull }
import org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericRowWithSchema}
+import org.apache.spark.sql.catalyst.expressions.{
+ Cast,
+ Expression,
+ GenericRowWithSchema
+}
import java.time.Instant
import org.apache.spark.sql.catalyst.plans.logical
@@ -45,7 +52,10 @@ class FramelessLitPushDownTests extends SQLRulesSuite {
test("struct push-down") {
type Payload = X4[Int, Int, Int, Int]
val expectedStructure = X1(X4(1, 2, 3, 4))
- val expected = new GenericRowWithSchema(Array(1, 2, 3, 4), TypedExpressionEncoder[Payload].schema)
+ val expected = new GenericRowWithSchema(
+ Array(1, 2, 3, 4),
+ TypedExpressionEncoder[Payload].schema
+ )
val expectedPushDownFilters = List(IsNotNull("a"))
predicatePushDownTest[Payload](
@@ -58,16 +68,18 @@ class FramelessLitPushDownTests extends SQLRulesSuite {
}
override def predicatePushDownTest[A: TypedEncoder: CatalystOrdered](
- expected: X1[A],
- expectedPushDownFilters: List[Filter],
- planShouldContain: PartialFunction[Expression, Expression],
- op: TypedColumn[X1[A], A] => TypedColumn[X1[A], Boolean]
- ): Assertion = {
+ expected: X1[A],
+ expectedPushDownFilters: List[Filter],
+ planShouldContain: PartialFunction[Expression, Expression],
+ op: TypedColumn[X1[A], A] => TypedColumn[X1[A], Boolean]
+ ): Assertion = {
withDataset(expected) { dataset =>
val ds = dataset.filter(op(dataset('a)))
val actualPushDownFilters = pushDownFilters(ds)
- val optimizedPlan = ds.queryExecution.optimizedPlan.collect { case logical.Filter(condition, _) => condition }.flatMap(_.toList)
+ val optimizedPlan = ds.queryExecution.optimizedPlan.collect {
+ case logical.Filter(condition, _) => condition
+ }.flatMap(_.toList)
// check the optimized plan
optimizedPlan.collectFirst(planShouldContain) should not be (empty)
diff --git a/dataset/src/test/spark-3.3+/frameless/sql/rules/FramelessLitPushDownTests.scala b/dataset/src/test/spark-3.3+/frameless/sql/rules/FramelessLitPushDownTests.scala
index 36a443fb5..93e723c2e 100644
--- a/dataset/src/test/spark-3.3+/frameless/sql/rules/FramelessLitPushDownTests.scala
+++ b/dataset/src/test/spark-3.3+/frameless/sql/rules/FramelessLitPushDownTests.scala
@@ -2,8 +2,11 @@ package frameless.sql.rules
import frameless._
import frameless.functions.Lit
-import org.apache.spark.sql.catalyst.util.DateTimeUtils.{currentTimestamp, microsToInstant}
-import org.apache.spark.sql.sources.{EqualTo, GreaterThanOrEqual, IsNotNull}
+import org.apache.spark.sql.catalyst.util.DateTimeUtils.{
+ currentTimestamp,
+ microsToInstant
+}
+import org.apache.spark.sql.sources.{ EqualTo, GreaterThanOrEqual, IsNotNull }
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import java.time.Instant
@@ -14,7 +17,8 @@ class FramelessLitPushDownTests extends SQLRulesSuite {
test("java.sql.Timestamp push-down") {
val expected = java.sql.Timestamp.from(microsToInstant(now))
val expectedStructure = X1(SQLTimestamp(now))
- val expectedPushDownFilters = List(IsNotNull("a"), GreaterThanOrEqual("a", expected))
+ val expectedPushDownFilters =
+ List(IsNotNull("a"), GreaterThanOrEqual("a", expected))
predicatePushDownTest[SQLTimestamp](
expectedStructure,
@@ -27,7 +31,8 @@ class FramelessLitPushDownTests extends SQLRulesSuite {
test("java.time.Instant push-down") {
val expected = java.sql.Timestamp.from(microsToInstant(now))
val expectedStructure = X1(microsToInstant(now))
- val expectedPushDownFilters = List(IsNotNull("a"), GreaterThanOrEqual("a", expected))
+ val expectedPushDownFilters =
+ List(IsNotNull("a"), GreaterThanOrEqual("a", expected))
predicatePushDownTest[Instant](
expectedStructure,
@@ -40,7 +45,10 @@ class FramelessLitPushDownTests extends SQLRulesSuite {
test("struct push-down") {
type Payload = X4[Int, Int, Int, Int]
val expectedStructure = X1(X4(1, 2, 3, 4))
- val expected = new GenericRowWithSchema(Array(1, 2, 3, 4), TypedExpressionEncoder[Payload].schema)
+ val expected = new GenericRowWithSchema(
+ Array(1, 2, 3, 4),
+ TypedExpressionEncoder[Payload].schema
+ )
val expectedPushDownFilters = List(IsNotNull("a"), EqualTo("a", expected))
predicatePushDownTest[Payload](
diff --git a/ml/src/main/scala/frameless/ml/TypedEstimator.scala b/ml/src/main/scala/frameless/ml/TypedEstimator.scala
index 2628d234a..6c6ca6029 100644
--- a/ml/src/main/scala/frameless/ml/TypedEstimator.scala
+++ b/ml/src/main/scala/frameless/ml/TypedEstimator.scala
@@ -2,19 +2,20 @@ package frameless
package ml
import frameless.ops.SmartProject
-import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.{ Estimator, Model }
/**
- * A TypedEstimator fits models to data.
- */
+ * A TypedEstimator fits models to data.
+ */
trait TypedEstimator[Inputs, Outputs, M <: Model[M]] {
val estimator: Estimator[M]
- def fit[T, F[_]](ds: TypedDataset[T])(
- implicit
- smartProject: SmartProject[T, Inputs],
- F: SparkDelay[F]
- ): F[AppendTransformer[Inputs, Outputs, M]] = {
+ def fit[T, F[_]](
+ ds: TypedDataset[T]
+ )(implicit
+ smartProject: SmartProject[T, Inputs],
+ F: SparkDelay[F]
+ ): F[AppendTransformer[Inputs, Outputs, M]] = {
implicit val sparkSession = ds.dataset.sparkSession
F.delay {
val inputDs = smartProject.apply(ds)
diff --git a/ml/src/main/scala/frameless/ml/TypedTransformer.scala b/ml/src/main/scala/frameless/ml/TypedTransformer.scala
index 99edc70e6..68dbd836c 100644
--- a/ml/src/main/scala/frameless/ml/TypedTransformer.scala
+++ b/ml/src/main/scala/frameless/ml/TypedTransformer.scala
@@ -3,31 +3,34 @@ package ml
import frameless.ops.SmartProject
import org.apache.spark.ml.Transformer
-import shapeless.{Generic, HList}
-import shapeless.ops.hlist.{Prepend, Tupler}
+import shapeless.{ Generic, HList }
+import shapeless.ops.hlist.{ Prepend, Tupler }
/**
- * A TypedTransformer transforms one TypedDataset into another.
- */
+ * A TypedTransformer transforms one TypedDataset into another.
+ */
sealed trait TypedTransformer
/**
- * An AppendTransformer `transform` method takes as input a TypedDataset containing `Inputs` and
- * return a TypedDataset with `Outputs` columns appended to the input TypedDataset.
- */
-trait AppendTransformer[Inputs, Outputs, InnerTransformer <: Transformer] extends TypedTransformer {
+ * An AppendTransformer `transform` method takes as input a TypedDataset containing `Inputs` and
+ * return a TypedDataset with `Outputs` columns appended to the input TypedDataset.
+ */
+trait AppendTransformer[Inputs, Outputs, InnerTransformer <: Transformer]
+ extends TypedTransformer {
val transformer: InnerTransformer
- def transform[T, TVals <: HList, OutputsVals <: HList, OutVals <: HList, Out](ds: TypedDataset[T])(
- implicit
- i0: SmartProject[T, Inputs],
- i1: Generic.Aux[T, TVals],
- i2: Generic.Aux[Outputs, OutputsVals],
- i3: Prepend.Aux[TVals, OutputsVals, OutVals],
- i4: Tupler.Aux[OutVals, Out],
- i5: TypedEncoder[Out]
- ): TypedDataset[Out] = {
- val transformed = transformer.transform(ds.dataset).as[Out](TypedExpressionEncoder[Out])
+ def transform[T, TVals <: HList, OutputsVals <: HList, OutVals <: HList, Out](
+ ds: TypedDataset[T]
+ )(implicit
+ i0: SmartProject[T, Inputs],
+ i1: Generic.Aux[T, TVals],
+ i2: Generic.Aux[Outputs, OutputsVals],
+ i3: Prepend.Aux[TVals, OutputsVals, OutVals],
+ i4: Tupler.Aux[OutVals, Out],
+ i5: TypedEncoder[Out]
+ ): TypedDataset[Out] = {
+ val transformed =
+ transformer.transform(ds.dataset).as[Out](TypedExpressionEncoder[Out])
TypedDataset.create[Out](transformed)
}
diff --git a/ml/src/main/scala/frameless/ml/classification/TypedRandomForestClassifier.scala b/ml/src/main/scala/frameless/ml/classification/TypedRandomForestClassifier.scala
index f6efcceaf..1bd757478 100644
--- a/ml/src/main/scala/frameless/ml/classification/TypedRandomForestClassifier.scala
+++ b/ml/src/main/scala/frameless/ml/classification/TypedRandomForestClassifier.scala
@@ -4,48 +4,82 @@ package classification
import frameless.ml.internals.TreesInputsChecker
import frameless.ml.params.trees.FeatureSubsetStrategy
-import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
+import org.apache.spark.ml.classification.{
+ RandomForestClassificationModel,
+ RandomForestClassifier
+}
import org.apache.spark.ml.linalg.Vector
/**
- * Random Forest learning algorithm for
- * classification.
- * It supports both binary and multiclass labels, as well as both continuous and categorical
- * features.
- */
-final class TypedRandomForestClassifier[Inputs] private[ml](
- rf: RandomForestClassifier,
- labelCol: String,
- featuresCol: String
-) extends TypedEstimator[Inputs, TypedRandomForestClassifier.Outputs, RandomForestClassificationModel] {
+ * Random Forest learning algorithm for
+ * classification.
+ * It supports both binary and multiclass labels, as well as both continuous and categorical
+ * features.
+ */
+final class TypedRandomForestClassifier[Inputs] private[ml] (
+ rf: RandomForestClassifier,
+ labelCol: String,
+ featuresCol: String)
+ extends TypedEstimator[
+ Inputs,
+ TypedRandomForestClassifier.Outputs,
+ RandomForestClassificationModel
+ ] {
val estimator: RandomForestClassifier =
- rf
- .setLabelCol(labelCol)
+ rf.setLabelCol(labelCol)
.setFeaturesCol(featuresCol)
.setPredictionCol(AppendTransformer.tempColumnName)
.setRawPredictionCol(AppendTransformer.tempColumnName2)
.setProbabilityCol(AppendTransformer.tempColumnName3)
- def setNumTrees(value: Int): TypedRandomForestClassifier[Inputs] = copy(rf.setNumTrees(value))
- def setMaxDepth(value: Int): TypedRandomForestClassifier[Inputs] = copy(rf.setMaxDepth(value))
- def setMinInfoGain(value: Double): TypedRandomForestClassifier[Inputs] = copy(rf.setMinInfoGain(value))
- def setMinInstancesPerNode(value: Int): TypedRandomForestClassifier[Inputs] = copy(rf.setMinInstancesPerNode(value))
- def setMaxMemoryInMB(value: Int): TypedRandomForestClassifier[Inputs] = copy(rf.setMaxMemoryInMB(value))
- def setSubsamplingRate(value: Double): TypedRandomForestClassifier[Inputs] = copy(rf.setSubsamplingRate(value))
- def setFeatureSubsetStrategy(value: FeatureSubsetStrategy): TypedRandomForestClassifier[Inputs] =
+ def setNumTrees(value: Int): TypedRandomForestClassifier[Inputs] =
+ copy(rf.setNumTrees(value))
+
+ def setMaxDepth(value: Int): TypedRandomForestClassifier[Inputs] =
+ copy(rf.setMaxDepth(value))
+
+ def setMinInfoGain(value: Double): TypedRandomForestClassifier[Inputs] =
+ copy(rf.setMinInfoGain(value))
+
+ def setMinInstancesPerNode(value: Int): TypedRandomForestClassifier[Inputs] =
+ copy(rf.setMinInstancesPerNode(value))
+
+ def setMaxMemoryInMB(value: Int): TypedRandomForestClassifier[Inputs] =
+ copy(rf.setMaxMemoryInMB(value))
+
+ def setSubsamplingRate(value: Double): TypedRandomForestClassifier[Inputs] =
+ copy(rf.setSubsamplingRate(value))
+
+ def setFeatureSubsetStrategy(
+ value: FeatureSubsetStrategy
+ ): TypedRandomForestClassifier[Inputs] =
copy(rf.setFeatureSubsetStrategy(value.sparkValue))
- def setMaxBins(value: Int): TypedRandomForestClassifier[Inputs] = copy(rf.setMaxBins(value))
- private def copy(newRf: RandomForestClassifier): TypedRandomForestClassifier[Inputs] =
+ def setMaxBins(value: Int): TypedRandomForestClassifier[Inputs] =
+ copy(rf.setMaxBins(value))
+
+ private def copy(
+ newRf: RandomForestClassifier
+ ): TypedRandomForestClassifier[Inputs] =
new TypedRandomForestClassifier[Inputs](newRf, labelCol, featuresCol)
}
object TypedRandomForestClassifier {
- case class Outputs(rawPrediction: Vector, probability: Vector, prediction: Double)
- def apply[Inputs](implicit inputsChecker: TreesInputsChecker[Inputs]): TypedRandomForestClassifier[Inputs] = {
- new TypedRandomForestClassifier(new RandomForestClassifier(), inputsChecker.labelCol, inputsChecker.featuresCol)
+ case class Outputs(
+ rawPrediction: Vector,
+ probability: Vector,
+ prediction: Double)
+
+ def apply[Inputs](
+ implicit
+ inputsChecker: TreesInputsChecker[Inputs]
+ ): TypedRandomForestClassifier[Inputs] = {
+ new TypedRandomForestClassifier(
+ new RandomForestClassifier(),
+ inputsChecker.labelCol,
+ inputsChecker.featuresCol
+ )
}
}
-
diff --git a/ml/src/main/scala/frameless/ml/clustering/TypedBisectingKMeans.scala b/ml/src/main/scala/frameless/ml/clustering/TypedBisectingKMeans.scala
index 4a8c974b4..210c9e589 100644
--- a/ml/src/main/scala/frameless/ml/clustering/TypedBisectingKMeans.scala
+++ b/ml/src/main/scala/frameless/ml/clustering/TypedBisectingKMeans.scala
@@ -3,39 +3,46 @@ package ml
package classification
import frameless.ml.internals.VectorInputsChecker
-import org.apache.spark.ml.clustering.{BisectingKMeans, BisectingKMeansModel}
+import org.apache.spark.ml.clustering.{ BisectingKMeans, BisectingKMeansModel }
/**
- * A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques"
- * by Steinbach, Karypis, and Kumar, with modification to fit Spark.
- * The algorithm starts from a single cluster that contains all points.
- * Iteratively it finds divisible clusters on the bottom level and bisects each of them using
- * k-means, until there are `k` leaf clusters in total or no leaf clusters are divisible.
- * The bisecting steps of clusters on the same level are grouped together to increase parallelism.
- * If bisecting all divisible clusters on the bottom level would result more than `k` leaf clusters,
- * larger clusters get higher priority.
- *
- * @see
- * Steinbach, Karypis, and Kumar, A comparison of document clustering techniques,
- * KDD Workshop on Text Mining, 2000.
- */
+ * A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques"
+ * by Steinbach, Karypis, and Kumar, with modification to fit Spark.
+ * The algorithm starts from a single cluster that contains all points.
+ * Iteratively it finds divisible clusters on the bottom level and bisects each of them using
+ * k-means, until there are `k` leaf clusters in total or no leaf clusters are divisible.
+ * The bisecting steps of clusters on the same level are grouped together to increase parallelism.
+ * If bisecting all divisible clusters on the bottom level would result more than `k` leaf clusters,
+ * larger clusters get higher priority.
+ *
+ * @see
+ * Steinbach, Karypis, and Kumar, A comparison of document clustering techniques,
+ * KDD Workshop on Text Mining, 2000.
+ */
class TypedBisectingKMeans[Inputs] private[ml] (
- bkm: BisectingKMeans,
- featuresCol: String
-) extends TypedEstimator[Inputs,TypedBisectingKMeans.Output, BisectingKMeansModel]{
+ bkm: BisectingKMeans,
+ featuresCol: String)
+ extends TypedEstimator[
+ Inputs,
+ TypedBisectingKMeans.Output,
+ BisectingKMeansModel
+ ] {
+
val estimator: BisectingKMeans =
bkm
- .setFeaturesCol(featuresCol)
- .setPredictionCol(AppendTransformer.tempColumnName)
-
+ .setFeaturesCol(featuresCol)
+ .setPredictionCol(AppendTransformer.tempColumnName)
+
def setK(value: Int): TypedBisectingKMeans[Inputs] = copy(bkm.setK(value))
-
- def setMaxIter(value: Int): TypedBisectingKMeans[Inputs] = copy(bkm.setMaxIter(value))
+
+ def setMaxIter(value: Int): TypedBisectingKMeans[Inputs] =
+ copy(bkm.setMaxIter(value))
def setMinDivisibleClusterSize(value: Double): TypedBisectingKMeans[Inputs] =
copy(bkm.setMinDivisibleClusterSize(value))
-
- def setSeed(value: Long): TypedBisectingKMeans[Inputs] = copy(bkm.setSeed(value))
+
+ def setSeed(value: Long): TypedBisectingKMeans[Inputs] =
+ copy(bkm.setSeed(value))
private def copy(newBkm: BisectingKMeans): TypedBisectingKMeans[Inputs] =
new TypedBisectingKMeans[Inputs](newBkm, featuresCol)
@@ -44,6 +51,9 @@ class TypedBisectingKMeans[Inputs] private[ml] (
object TypedBisectingKMeans {
case class Output(prediction: Int)
- def apply[Inputs]()(implicit inputsChecker: VectorInputsChecker[Inputs]): TypedBisectingKMeans[Inputs] =
+ def apply[Inputs](
+ )(implicit
+ inputsChecker: VectorInputsChecker[Inputs]
+ ): TypedBisectingKMeans[Inputs] =
new TypedBisectingKMeans(new BisectingKMeans(), inputsChecker.featuresCol)
-}
\ No newline at end of file
+}
diff --git a/ml/src/main/scala/frameless/ml/clustering/TypedKMeans.scala b/ml/src/main/scala/frameless/ml/clustering/TypedKMeans.scala
index 1a32076a5..ed082a20b 100644
--- a/ml/src/main/scala/frameless/ml/clustering/TypedKMeans.scala
+++ b/ml/src/main/scala/frameless/ml/clustering/TypedKMeans.scala
@@ -4,27 +4,29 @@ package classification
import frameless.ml.internals.VectorInputsChecker
import frameless.ml.params.kmeans.KMeansInitMode
-import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
+import org.apache.spark.ml.clustering.{ KMeans, KMeansModel }
/**
- * K-means clustering with support for k-means|| initialization proposed by Bahmani et al.
- *
- * @see Bahmani et al., Scalable k-means++.
- */
+ * K-means clustering with support for k-means|| initialization proposed by Bahmani et al.
+ *
+ * @see Bahmani et al., Scalable k-means++.
+ */
class TypedKMeans[Inputs] private[ml] (
- km: KMeans,
- featuresCol: String
-) extends TypedEstimator[Inputs,TypedKMeans.Output,KMeansModel] {
+ km: KMeans,
+ featuresCol: String)
+ extends TypedEstimator[Inputs, TypedKMeans.Output, KMeansModel] {
+
val estimator: KMeans =
- km
- .setFeaturesCol(featuresCol)
+ km.setFeaturesCol(featuresCol)
.setPredictionCol(AppendTransformer.tempColumnName)
def setK(value: Int): TypedKMeans[Inputs] = copy(km.setK(value))
- def setInitMode(value: KMeansInitMode): TypedKMeans[Inputs] = copy(km.setInitMode(value.sparkValue))
+ def setInitMode(value: KMeansInitMode): TypedKMeans[Inputs] =
+ copy(km.setInitMode(value.sparkValue))
- def setInitSteps(value: Int): TypedKMeans[Inputs] = copy(km.setInitSteps(value))
+ def setInitSteps(value: Int): TypedKMeans[Inputs] =
+ copy(km.setInitSteps(value))
def setMaxIter(value: Int): TypedKMeans[Inputs] = copy(km.setMaxIter(value))
@@ -32,14 +34,18 @@ class TypedKMeans[Inputs] private[ml] (
def setSeed(value: Long): TypedKMeans[Inputs] = copy(km.setSeed(value))
- private def copy(newKmeans: KMeans): TypedKMeans[Inputs] = new TypedKMeans[Inputs](newKmeans, featuresCol)
+ private def copy(newKmeans: KMeans): TypedKMeans[Inputs] =
+ new TypedKMeans[Inputs](newKmeans, featuresCol)
}
-object TypedKMeans{
+object TypedKMeans {
case class Output(prediction: Int)
- def apply[Inputs](implicit inputsChecker: VectorInputsChecker[Inputs]): TypedKMeans[Inputs] = {
+ def apply[Inputs](
+ implicit
+ inputsChecker: VectorInputsChecker[Inputs]
+ ): TypedKMeans[Inputs] = {
new TypedKMeans(new KMeans(), inputsChecker.featuresCol)
}
}
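Note: for comparison with the bisecting variant, a hedged usage sketch of TypedKMeans as reformatted above. `KMeansInput` and the literal values are illustrative, and an implicit SparkSession is assumed; `KMeansInitMode.Random` is the init mode shown in this diff, so it is the one used here:

    import frameless.TypedDataset
    import frameless.ml._
    import frameless.ml.classification.TypedKMeans
    import frameless.ml.params.kmeans.KMeansInitMode
    import org.apache.spark.ml.linalg.{ Vector, Vectors }

    case class KMeansInput(features: Vector)

    val km = TypedKMeans[KMeansInput] // apply takes only an implicit VectorInputsChecker
      .setK(2)
      .setInitMode(KMeansInitMode.Random)
      .setMaxIter(20)
      .setSeed(42L)

    val ds = TypedDataset.create(
      Seq(KMeansInput(Vectors.dense(1.0)), KMeansInput(Vectors.dense(100.0)))
    )
    val model = km.fit(ds).run()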
diff --git a/ml/src/main/scala/frameless/ml/feature/TypedIndexToString.scala b/ml/src/main/scala/frameless/ml/feature/TypedIndexToString.scala
index af2e9684a..f650ed678 100644
--- a/ml/src/main/scala/frameless/ml/feature/TypedIndexToString.scala
+++ b/ml/src/main/scala/frameless/ml/feature/TypedIndexToString.scala
@@ -6,14 +6,20 @@ import frameless.ml.internals.UnaryInputsChecker
import org.apache.spark.ml.feature.IndexToString
/**
- * A `TypedTransformer` that maps a column of indices back to a new column of corresponding
- * string values.
- * The index-string mapping must be supplied when creating the `TypedIndexToString`.
- *
- * @see `TypedStringIndexer` for converting strings into indices
- */
-final class TypedIndexToString[Inputs] private[ml](indexToString: IndexToString, inputCol: String)
- extends AppendTransformer[Inputs, TypedIndexToString.Outputs, IndexToString] {
+ * A `TypedTransformer` that maps a column of indices back to a new column of corresponding
+ * string values.
+ * The index-string mapping must be supplied when creating the `TypedIndexToString`.
+ *
+ * @see `TypedStringIndexer` for converting strings into indices
+ */
+final class TypedIndexToString[Inputs] private[ml] (
+ indexToString: IndexToString,
+ inputCol: String)
+ extends AppendTransformer[
+ Inputs,
+ TypedIndexToString.Outputs,
+ IndexToString
+ ] {
val transformer: IndexToString =
indexToString
@@ -25,8 +31,14 @@ final class TypedIndexToString[Inputs] private[ml](indexToString: IndexToString,
object TypedIndexToString {
case class Outputs(originalOutput: String)
- def apply[Inputs](labels: Array[String])
- (implicit inputsChecker: UnaryInputsChecker[Inputs, Double]): TypedIndexToString[Inputs] = {
- new TypedIndexToString[Inputs](new IndexToString().setLabels(labels), inputsChecker.inputCol)
+ def apply[Inputs](
+ labels: Array[String]
+ )(implicit
+ inputsChecker: UnaryInputsChecker[Inputs, Double]
+ ): TypedIndexToString[Inputs] = {
+ new TypedIndexToString[Inputs](
+ new IndexToString().setLabels(labels),
+ inputsChecker.inputCol
+ )
}
-}
\ No newline at end of file
+}
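Note: a small sketch of constructing the transformer above; the `Prediction` case class and the labels array are illustrative:

    import frameless.ml.feature.TypedIndexToString

    case class Prediction(predictedIndex: Double) // exactly one Double field, per UnaryInputsChecker

    // labels(i) is the string written back for index i
    val indexToString = TypedIndexToString[Prediction](Array("cat", "dog"))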
diff --git a/ml/src/main/scala/frameless/ml/feature/TypedStringIndexer.scala b/ml/src/main/scala/frameless/ml/feature/TypedStringIndexer.scala
index 7eba8e306..d398e3581 100644
--- a/ml/src/main/scala/frameless/ml/feature/TypedStringIndexer.scala
+++ b/ml/src/main/scala/frameless/ml/feature/TypedStringIndexer.scala
@@ -4,25 +4,34 @@ package feature
import frameless.ml.feature.TypedStringIndexer.HandleInvalid
import frameless.ml.internals.UnaryInputsChecker
-import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel}
+import org.apache.spark.ml.feature.{ StringIndexer, StringIndexerModel }
/**
- * A label indexer that maps a string column of labels to an ML column of label indices.
- * The indices are in [0, numLabels), ordered by label frequencies.
- * So the most frequent label gets index 0.
- *
- * @see `TypedIndexToString` for the inverse transformation
- */
-final class TypedStringIndexer[Inputs] private[ml](stringIndexer: StringIndexer, inputCol: String)
- extends TypedEstimator[Inputs, TypedStringIndexer.Outputs, StringIndexerModel] {
+ * A label indexer that maps a string column of labels to an ML column of label indices.
+ * The indices are in [0, numLabels), ordered by label frequencies.
+ * So the most frequent label gets index 0.
+ *
+ * @see `TypedIndexToString` for the inverse transformation
+ */
+final class TypedStringIndexer[Inputs] private[ml] (
+ stringIndexer: StringIndexer,
+ inputCol: String)
+ extends TypedEstimator[
+ Inputs,
+ TypedStringIndexer.Outputs,
+ StringIndexerModel
+ ] {
val estimator: StringIndexer = stringIndexer
.setInputCol(inputCol)
.setOutputCol(AppendTransformer.tempColumnName)
- def setHandleInvalid(value: HandleInvalid): TypedStringIndexer[Inputs] = copy(stringIndexer.setHandleInvalid(value.sparkValue))
+ def setHandleInvalid(value: HandleInvalid): TypedStringIndexer[Inputs] =
+ copy(stringIndexer.setHandleInvalid(value.sparkValue))
- private def copy(newStringIndexer: StringIndexer): TypedStringIndexer[Inputs] =
+ private def copy(
+ newStringIndexer: StringIndexer
+ ): TypedStringIndexer[Inputs] =
new TypedStringIndexer[Inputs](newStringIndexer, inputCol)
}
@@ -30,13 +39,17 @@ object TypedStringIndexer {
case class Outputs(indexedOutput: Double)
sealed abstract class HandleInvalid(val sparkValue: String)
+
object HandleInvalid {
case object Error extends HandleInvalid("error")
case object Skip extends HandleInvalid("skip")
case object Keep extends HandleInvalid("keep")
}
- def apply[Inputs](implicit inputsChecker: UnaryInputsChecker[Inputs, String]): TypedStringIndexer[Inputs] = {
+ def apply[Inputs](
+ implicit
+ inputsChecker: UnaryInputsChecker[Inputs, String]
+ ): TypedStringIndexer[Inputs] = {
new TypedStringIndexer[Inputs](new StringIndexer(), inputsChecker.inputCol)
}
-}
\ No newline at end of file
+}
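Note: a usage sketch showing how the HandleInvalid ADT above reaches the underlying Spark StringIndexer; `City` and the sample data are illustrative, and an implicit SparkSession is assumed:

    import frameless.TypedDataset
    import frameless.ml._
    import frameless.ml.feature.TypedStringIndexer
    import frameless.ml.feature.TypedStringIndexer.HandleInvalid

    case class City(name: String) // exactly one String field, per UnaryInputsChecker

    val indexer = TypedStringIndexer[City]
      .setHandleInvalid(HandleInvalid.Keep) // Error, Skip, Keep map to "error", "skip", "keep"

    val ds = TypedDataset.create(Seq(City("paris"), City("paris"), City("lyon")))
    val indexerModel = indexer.fit(ds).run()
    // the fitted model appends a Double index column; the most frequent label gets index 0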
diff --git a/ml/src/main/scala/frameless/ml/feature/TypedVectorAssembler.scala b/ml/src/main/scala/frameless/ml/feature/TypedVectorAssembler.scala
index d599011b3..4008462b9 100644
--- a/ml/src/main/scala/frameless/ml/feature/TypedVectorAssembler.scala
+++ b/ml/src/main/scala/frameless/ml/feature/TypedVectorAssembler.scala
@@ -4,17 +4,23 @@ package feature
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vector
-import shapeless.{HList, HNil, LabelledGeneric}
+import shapeless.{ HList, HNil, LabelledGeneric }
import shapeless.ops.hlist.ToTraversable
-import shapeless.ops.record.{Keys, Values}
+import shapeless.ops.record.{ Keys, Values }
import shapeless._
import scala.annotation.implicitNotFound
/**
- * A feature transformer that merges multiple columns into a vector column.
- */
-final class TypedVectorAssembler[Inputs] private[ml](vectorAssembler: VectorAssembler, inputCols: Array[String])
- extends AppendTransformer[Inputs, TypedVectorAssembler.Output, VectorAssembler] {
+ * A feature transformer that merges multiple columns into a vector column.
+ */
+final class TypedVectorAssembler[Inputs] private[ml] (
+ vectorAssembler: VectorAssembler,
+ inputCols: Array[String])
+ extends AppendTransformer[
+ Inputs,
+ TypedVectorAssembler.Output,
+ VectorAssembler
+ ] {
val transformer: VectorAssembler = vectorAssembler
.setInputCols(inputCols)
@@ -25,8 +31,14 @@ final class TypedVectorAssembler[Inputs] private[ml](vectorAssembler: VectorAsse
object TypedVectorAssembler {
case class Output(vector: Vector)
- def apply[Inputs](implicit inputsChecker: TypedVectorAssemblerInputsChecker[Inputs]): TypedVectorAssembler[Inputs] = {
- new TypedVectorAssembler(new VectorAssembler(), inputsChecker.inputCols.toArray)
+ def apply[Inputs](
+ implicit
+ inputsChecker: TypedVectorAssemblerInputsChecker[Inputs]
+ ): TypedVectorAssembler[Inputs] = {
+ new TypedVectorAssembler(
+ new VectorAssembler(),
+ inputsChecker.inputCols.toArray
+ )
}
}
@@ -38,32 +50,41 @@ private[ml] trait TypedVectorAssemblerInputsChecker[Inputs] {
}
private[ml] object TypedVectorAssemblerInputsChecker {
- implicit def checkInputs[Inputs, InputsRec <: HList, InputsKeys <: HList, InputsVals <: HList](
- implicit
- inputsGen: LabelledGeneric.Aux[Inputs, InputsRec],
- inputsKeys: Keys.Aux[InputsRec, InputsKeys],
- inputsKeysTraverse: ToTraversable.Aux[InputsKeys, Seq, Symbol],
- inputsValues: Values.Aux[InputsRec, InputsVals],
- inputsTypeCheck: TypedVectorAssemblerInputsValueChecker[InputsVals]
- ): TypedVectorAssemblerInputsChecker[Inputs] = new TypedVectorAssemblerInputsChecker[Inputs] {
- val inputCols: Seq[String] = inputsKeys.apply().to[Seq].map(_.name)
- }
+
+ implicit def checkInputs[
+ Inputs,
+ InputsRec <: HList,
+ InputsKeys <: HList,
+ InputsVals <: HList
+ ](implicit
+ inputsGen: LabelledGeneric.Aux[Inputs, InputsRec],
+ inputsKeys: Keys.Aux[InputsRec, InputsKeys],
+ inputsKeysTraverse: ToTraversable.Aux[InputsKeys, Seq, Symbol],
+ inputsValues: Values.Aux[InputsRec, InputsVals],
+ inputsTypeCheck: TypedVectorAssemblerInputsValueChecker[InputsVals]
+ ): TypedVectorAssemblerInputsChecker[Inputs] =
+ new TypedVectorAssemblerInputsChecker[Inputs] {
+ val inputCols: Seq[String] = inputsKeys.apply().to[Seq].map(_.name)
+ }
}
private[ml] trait TypedVectorAssemblerInputsValueChecker[InputsVals]
private[ml] object TypedVectorAssemblerInputsValueChecker {
+
implicit def hnilCheckInputsValue: TypedVectorAssemblerInputsValueChecker[HNil] =
new TypedVectorAssemblerInputsValueChecker[HNil] {}
implicit def hlistCheckInputsValueNumeric[H, T <: HList](
- implicit ch: CatalystNumeric[H],
- tt: TypedVectorAssemblerInputsValueChecker[T]
- ): TypedVectorAssemblerInputsValueChecker[H :: T] = new TypedVectorAssemblerInputsValueChecker[H :: T] {}
+ implicit
+ ch: CatalystNumeric[H],
+ tt: TypedVectorAssemblerInputsValueChecker[T]
+ ): TypedVectorAssemblerInputsValueChecker[H :: T] =
+ new TypedVectorAssemblerInputsValueChecker[H :: T] {}
implicit def hlistCheckInputsValueBoolean[T <: HList](
- implicit tt: TypedVectorAssemblerInputsValueChecker[T]
- ): TypedVectorAssemblerInputsValueChecker[Boolean :: T] = new TypedVectorAssemblerInputsValueChecker[Boolean :: T] {}
+ implicit
+ tt: TypedVectorAssemblerInputsValueChecker[T]
+ ): TypedVectorAssemblerInputsValueChecker[Boolean :: T] =
+ new TypedVectorAssemblerInputsValueChecker[Boolean :: T] {}
}
-
-
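Note: a sketch of the assembler in use, matching the inputs checker above: every input field must be numeric or Boolean, and the appended column is a Vector. `Measures`, `WithFeatures`, and the sample row are illustrative, with an implicit SparkSession assumed:

    import frameless.TypedDataset
    import frameless.ml._
    import frameless.ml.feature.TypedVectorAssembler
    import org.apache.spark.ml.linalg.Vector

    case class Measures(height: Double, weight: Int, smoker: Boolean)

    case class WithFeatures(
        height: Double,
        weight: Int,
        smoker: Boolean,
        features: Vector)

    val assembler = TypedVectorAssembler[Measures]
    val ds = TypedDataset.create(Seq(Measures(1.80, 75, smoker = false)))
    val assembled = assembler.transform(ds).as[WithFeatures]()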
diff --git a/ml/src/main/scala/frameless/ml/internals/LinearInputsChecker.scala b/ml/src/main/scala/frameless/ml/internals/LinearInputsChecker.scala
index 995a3f961..edc645397 100644
--- a/ml/src/main/scala/frameless/ml/internals/LinearInputsChecker.scala
+++ b/ml/src/main/scala/frameless/ml/internals/LinearInputsChecker.scala
@@ -4,13 +4,13 @@ package internals
import org.apache.spark.ml.linalg._
import shapeless.ops.hlist.Length
-import shapeless.{HList, LabelledGeneric, Nat, Witness}
+import shapeless.{ HList, LabelledGeneric, Nat, Witness }
import scala.annotation.implicitNotFound
/**
- * Can be used for linear reg algorithm
- */
+ * Can be used for linear regression algorithms
+ */
@implicitNotFound(
msg = "Cannot prove that ${Inputs} is a valid input type. " +
"Input type must only contain a field of type Double (the label) and a field of type " +
@@ -25,18 +25,18 @@ trait LinearInputsChecker[Inputs] {
object LinearInputsChecker {
implicit def checkLinearInputs[
- Inputs,
- InputsRec <: HList,
- LabelK <: Symbol,
- FeaturesK <: Symbol](
- implicit
- i0: LabelledGeneric.Aux[Inputs, InputsRec],
- i1: Length.Aux[InputsRec, Nat._2],
- i2: SelectorByValue.Aux[InputsRec, Double, LabelK],
- i3: Witness.Aux[LabelK],
- i4: SelectorByValue.Aux[InputsRec, Vector, FeaturesK],
- i5: Witness.Aux[FeaturesK]
- ): LinearInputsChecker[Inputs] = {
+ Inputs,
+ InputsRec <: HList,
+ LabelK <: Symbol,
+ FeaturesK <: Symbol
+ ](implicit
+ i0: LabelledGeneric.Aux[Inputs, InputsRec],
+ i1: Length.Aux[InputsRec, Nat._2],
+ i2: SelectorByValue.Aux[InputsRec, Double, LabelK],
+ i3: Witness.Aux[LabelK],
+ i4: SelectorByValue.Aux[InputsRec, Vector, FeaturesK],
+ i5: Witness.Aux[FeaturesK]
+ ): LinearInputsChecker[Inputs] = {
new LinearInputsChecker[Inputs] {
val labelCol: String = implicitly[Witness.Aux[LabelK]].value.name
val featuresCol: String = implicitly[Witness.Aux[FeaturesK]].value.name
@@ -45,25 +45,27 @@ object LinearInputsChecker {
}
implicit def checkLinearInputs2[
- Inputs,
- InputsRec <: HList,
- LabelK <: Symbol,
- FeaturesK <: Symbol,
- WeightK <: Symbol](
- implicit
- i0: LabelledGeneric.Aux[Inputs, InputsRec],
- i1: Length.Aux[InputsRec, Nat._3],
- i2: SelectorByValue.Aux[InputsRec, Vector, FeaturesK],
- i3: Witness.Aux[FeaturesK],
- i4: SelectorByValue.Aux[InputsRec, Double, LabelK],
- i5: Witness.Aux[LabelK],
- i6: SelectorByValue.Aux[InputsRec, Float, WeightK],
- i7: Witness.Aux[WeightK]
- ): LinearInputsChecker[Inputs] = {
+ Inputs,
+ InputsRec <: HList,
+ LabelK <: Symbol,
+ FeaturesK <: Symbol,
+ WeightK <: Symbol
+ ](implicit
+ i0: LabelledGeneric.Aux[Inputs, InputsRec],
+ i1: Length.Aux[InputsRec, Nat._3],
+ i2: SelectorByValue.Aux[InputsRec, Vector, FeaturesK],
+ i3: Witness.Aux[FeaturesK],
+ i4: SelectorByValue.Aux[InputsRec, Double, LabelK],
+ i5: Witness.Aux[LabelK],
+ i6: SelectorByValue.Aux[InputsRec, Float, WeightK],
+ i7: Witness.Aux[WeightK]
+ ): LinearInputsChecker[Inputs] = {
new LinearInputsChecker[Inputs] {
val labelCol: String = implicitly[Witness.Aux[LabelK]].value.name
val featuresCol: String = implicitly[Witness.Aux[FeaturesK]].value.name
- val weightCol: Option[String] = Some(implicitly[Witness.Aux[WeightK]].value.name)
+ val weightCol: Option[String] = Some(
+ implicitly[Witness.Aux[WeightK]].value.name
+ )
}
}
diff --git a/ml/src/main/scala/frameless/ml/internals/SelectorByValue.scala b/ml/src/main/scala/frameless/ml/internals/SelectorByValue.scala
index 9a67d5299..7292ff855 100644
--- a/ml/src/main/scala/frameless/ml/internals/SelectorByValue.scala
+++ b/ml/src/main/scala/frameless/ml/internals/SelectorByValue.scala
@@ -3,24 +3,35 @@ package ml
package internals
import shapeless.labelled.FieldType
-import shapeless.{::, DepFn1, HList, Witness}
+import shapeless.{ ::, DepFn1, HList, Witness }
/**
- * Typeclass supporting record selection by value type (returning the first key whose value is of type `Value`)
- */
-trait SelectorByValue[L <: HList, Value] extends DepFn1[L] with Serializable { type Out <: Symbol }
+ * Typeclass supporting record selection by value type (returning the first key whose value is of type `Value`)
+ */
+trait SelectorByValue[L <: HList, Value] extends DepFn1[L] with Serializable {
+ type Out <: Symbol
+}
object SelectorByValue {
- type Aux[L <: HList, Value, Out0 <: Symbol] = SelectorByValue[L, Value] { type Out = Out0 }
- implicit def select[K <: Symbol, T <: HList, Value](implicit wk: Witness.Aux[K]): Aux[FieldType[K, Value] :: T, Value, K] = {
+ type Aux[L <: HList, Value, Out0 <: Symbol] = SelectorByValue[L, Value] {
+ type Out = Out0
+ }
+
+ implicit def select[K <: Symbol, T <: HList, Value](
+ implicit
+ wk: Witness.Aux[K]
+ ): Aux[FieldType[K, Value] :: T, Value, K] = {
new SelectorByValue[FieldType[K, Value] :: T, Value] {
type Out = K
def apply(l: FieldType[K, Value] :: T): Out = wk.value
}
}
- implicit def recurse[H, T <: HList, Value](implicit st: SelectorByValue[T, Value]): Aux[H :: T, Value, st.Out] = {
+ implicit def recurse[H, T <: HList, Value](
+ implicit
+ st: SelectorByValue[T, Value]
+ ): Aux[H :: T, Value, st.Out] = {
new SelectorByValue[H :: T, Value] {
type Out = st.Out
def apply(l: H :: T): Out = st(l.tail)
diff --git a/ml/src/main/scala/frameless/ml/internals/TreesInputsChecker.scala b/ml/src/main/scala/frameless/ml/internals/TreesInputsChecker.scala
index 0fe157654..a7fd7f2c0 100644
--- a/ml/src/main/scala/frameless/ml/internals/TreesInputsChecker.scala
+++ b/ml/src/main/scala/frameless/ml/internals/TreesInputsChecker.scala
@@ -3,14 +3,14 @@ package ml
package internals
import shapeless.ops.hlist.Length
-import shapeless.{HList, LabelledGeneric, Nat, Witness}
+import shapeless.{ HList, LabelledGeneric, Nat, Witness }
import org.apache.spark.ml.linalg._
import scala.annotation.implicitNotFound
/**
- * Can be used for all tree-based ML algorithm (decision tree, random forest, gradient-boosted trees)
- */
+ * Can be used for all tree-based ML algorithms (decision tree, random forest, gradient-boosted trees)
+ */
@implicitNotFound(
msg = "Cannot prove that ${Inputs} is a valid input type. " +
"Input type must only contain a field of type Double (the label) and a field of type " +
@@ -24,18 +24,18 @@ trait TreesInputsChecker[Inputs] {
object TreesInputsChecker {
implicit def checkTreesInputs[
- Inputs,
- InputsRec <: HList,
- LabelK <: Symbol,
- FeaturesK <: Symbol](
- implicit
- i0: LabelledGeneric.Aux[Inputs, InputsRec],
- i1: Length.Aux[InputsRec, Nat._2],
- i2: SelectorByValue.Aux[InputsRec, Double, LabelK],
- i3: Witness.Aux[LabelK],
- i4: SelectorByValue.Aux[InputsRec, Vector, FeaturesK],
- i5: Witness.Aux[FeaturesK]
- ): TreesInputsChecker[Inputs] = {
+ Inputs,
+ InputsRec <: HList,
+ LabelK <: Symbol,
+ FeaturesK <: Symbol
+ ](implicit
+ i0: LabelledGeneric.Aux[Inputs, InputsRec],
+ i1: Length.Aux[InputsRec, Nat._2],
+ i2: SelectorByValue.Aux[InputsRec, Double, LabelK],
+ i3: Witness.Aux[LabelK],
+ i4: SelectorByValue.Aux[InputsRec, Vector, FeaturesK],
+ i5: Witness.Aux[FeaturesK]
+ ): TreesInputsChecker[Inputs] = {
new TreesInputsChecker[Inputs] {
val labelCol: String = implicitly[Witness.Aux[LabelK]].value.name
val featuresCol: String = implicitly[Witness.Aux[FeaturesK]].value.name
diff --git a/ml/src/main/scala/frameless/ml/internals/UnaryInputsChecker.scala b/ml/src/main/scala/frameless/ml/internals/UnaryInputsChecker.scala
index 56dfc9a57..51b8e1306 100644
--- a/ml/src/main/scala/frameless/ml/internals/UnaryInputsChecker.scala
+++ b/ml/src/main/scala/frameless/ml/internals/UnaryInputsChecker.scala
@@ -3,13 +3,13 @@ package ml
package internals
import shapeless.ops.hlist.Length
-import shapeless.{HList, LabelledGeneric, Nat, Witness}
+import shapeless.{ HList, LabelledGeneric, Nat, Witness }
import scala.annotation.implicitNotFound
/**
- * Can be used for all unary transformers (i.e almost all of them)
- */
+ * Can be used for all unary transformers (i.e. almost all of them)
+ */
@implicitNotFound(
msg = "Cannot prove that ${Inputs} is a valid input type. Input type must have only one field of type ${Expected}"
)
@@ -19,15 +19,19 @@ trait UnaryInputsChecker[Inputs, Expected] {
object UnaryInputsChecker {
- implicit def checkUnaryInputs[Inputs, Expected, InputsRec <: HList, InputK <: Symbol](
- implicit
- i0: LabelledGeneric.Aux[Inputs, InputsRec],
- i1: Length.Aux[InputsRec, Nat._1],
- i2: SelectorByValue.Aux[InputsRec, Expected, InputK],
- i3: Witness.Aux[InputK]
- ): UnaryInputsChecker[Inputs, Expected] = new UnaryInputsChecker[Inputs, Expected] {
- val inputCol: String = implicitly[Witness.Aux[InputK]].value.name
- }
+ implicit def checkUnaryInputs[
+ Inputs,
+ Expected,
+ InputsRec <: HList,
+ InputK <: Symbol
+ ](implicit
+ i0: LabelledGeneric.Aux[Inputs, InputsRec],
+ i1: Length.Aux[InputsRec, Nat._1],
+ i2: SelectorByValue.Aux[InputsRec, Expected, InputK],
+ i3: Witness.Aux[InputK]
+ ): UnaryInputsChecker[Inputs, Expected] =
+ new UnaryInputsChecker[Inputs, Expected] {
+ val inputCol: String = implicitly[Witness.Aux[InputK]].value.name
+ }
}
-
diff --git a/ml/src/main/scala/frameless/ml/internals/VectorInputsChecker.scala b/ml/src/main/scala/frameless/ml/internals/VectorInputsChecker.scala
index e993d9a55..13f199de0 100644
--- a/ml/src/main/scala/frameless/ml/internals/VectorInputsChecker.scala
+++ b/ml/src/main/scala/frameless/ml/internals/VectorInputsChecker.scala
@@ -3,7 +3,7 @@ package ml
package internals
import shapeless.ops.hlist.Length
-import shapeless.{HList, LabelledGeneric, Nat, Witness}
+import shapeless.{ HList, LabelledGeneric, Nat, Witness }
import scala.annotation.implicitNotFound
import org.apache.spark.ml.linalg.Vector
@@ -18,15 +18,19 @@ trait VectorInputsChecker[Inputs] {
}
object VectorInputsChecker {
- implicit def checkVectorInput[Inputs, InputsRec <: HList, FeaturesK <: Symbol](
- implicit
+
+ implicit def checkVectorInput[
+ Inputs,
+ InputsRec <: HList,
+ FeaturesK <: Symbol
+ ](implicit
i0: LabelledGeneric.Aux[Inputs, InputsRec],
i1: Length.Aux[InputsRec, Nat._1],
i2: SelectorByValue.Aux[InputsRec, Vector, FeaturesK],
i3: Witness.Aux[FeaturesK]
): VectorInputsChecker[Inputs] = {
- new VectorInputsChecker[Inputs] {
- val featuresCol: String = i3.value.name
- }
+ new VectorInputsChecker[Inputs] {
+ val featuresCol: String = i3.value.name
}
+ }
}
diff --git a/ml/src/main/scala/frameless/ml/package.scala b/ml/src/main/scala/frameless/ml/package.scala
index d1c306158..8478c1332 100644
--- a/ml/src/main/scala/frameless/ml/package.scala
+++ b/ml/src/main/scala/frameless/ml/package.scala
@@ -2,12 +2,14 @@ package frameless
import org.apache.spark.sql.FramelessInternals.UserDefinedType
import org.apache.spark.ml.FramelessInternals
-import org.apache.spark.ml.linalg.{Matrix, Vector}
+import org.apache.spark.ml.linalg.{ Matrix, Vector }
package object ml {
- implicit val mlVectorUdt: UserDefinedType[Vector] = FramelessInternals.vectorUdt
+ implicit val mlVectorUdt: UserDefinedType[Vector] =
+ FramelessInternals.vectorUdt
- implicit val mlMatrixUdt: UserDefinedType[Matrix] = FramelessInternals.matrixUdt
+ implicit val mlMatrixUdt: UserDefinedType[Matrix] =
+ FramelessInternals.matrixUdt
}
diff --git a/ml/src/main/scala/frameless/ml/params/kmeans/KMeansInitMode.scala b/ml/src/main/scala/frameless/ml/params/kmeans/KMeansInitMode.scala
index b3c023735..c1bca8e58 100644
--- a/ml/src/main/scala/frameless/ml/params/kmeans/KMeansInitMode.scala
+++ b/ml/src/main/scala/frameless/ml/params/kmeans/KMeansInitMode.scala
@@ -4,14 +4,14 @@ package params
package kmeans
/**
- * Param for the initialization algorithm.
- * This can be either "random" to choose random points as
- * initial cluster centers, or "k-means||" to use a parallel variant of k-means++
- * (Bahmani et al., Scalable K-Means++, VLDB 2012).
- * Default: k-means||.
- */
+ * Param for the initialization algorithm.
+ * This can be either "random" to choose random points as
+ * initial cluster centers, or "k-means||" to use a parallel variant of k-means++
+ * (Bahmani et al., Scalable K-Means++, VLDB 2012).
+ * Default: k-means||.
+ */
-sealed abstract class KMeansInitMode private[ml](val sparkValue: String)
+sealed abstract class KMeansInitMode private[ml] (val sparkValue: String)
object KMeansInitMode {
case object Random extends KMeansInitMode("random")
diff --git a/ml/src/main/scala/frameless/ml/params/linears/LossStrategy.scala b/ml/src/main/scala/frameless/ml/params/linears/LossStrategy.scala
index 4b9ca6d4e..0dbd896c4 100644
--- a/ml/src/main/scala/frameless/ml/params/linears/LossStrategy.scala
+++ b/ml/src/main/scala/frameless/ml/params/linears/LossStrategy.scala
@@ -2,15 +2,17 @@ package frameless
package ml
package params
package linears
+
/**
- * SquaredError measures the average of the squares of the errors—that is,
- * the average squared difference between the estimated values and what is estimated.
- *
- * Huber Loss loss function less sensitive to outliers in data than the
- * squared error loss
- */
-sealed abstract class LossStrategy private[ml](val sparkValue: String)
+ * SquaredError measures the average of the squares of the errors—that is,
+ * the average squared difference between the estimated values and what is estimated.
+ *
+ * Huber loss is a loss function less sensitive to outliers in the data than the
+ * squared error loss.
+ */
+sealed abstract class LossStrategy private[ml] (val sparkValue: String)
+
object LossStrategy {
case object SquaredError extends LossStrategy("squaredError")
- case object Huber extends LossStrategy("huber")
+ case object Huber extends LossStrategy("huber")
}
diff --git a/ml/src/main/scala/frameless/ml/params/linears/Solver.scala b/ml/src/main/scala/frameless/ml/params/linears/Solver.scala
index 277e06e7a..7b14b8b95 100644
--- a/ml/src/main/scala/frameless/ml/params/linears/Solver.scala
+++ b/ml/src/main/scala/frameless/ml/params/linears/Solver.scala
@@ -4,22 +4,22 @@ package params
package linears
/**
- * solver algorithm used for optimization.
- * - "l-bfgs" denotes Limited-memory BFGS which is a limited-memory quasi-Newton
- * optimization method.
- * - "normal" denotes using Normal Equation as an analytical solution to the linear regression
- * problem. This solver is limited to `LinearRegression.MAX_FEATURES_FOR_NORMAL_SOLVER`.
- * - "auto" (default) means that the solver algorithm is selected automatically.
- * The Normal Equations solver will be used when possible, but this will automatically fall
- * back to iterative optimization methods when needed.
- *
- * spark
- */
+ * solver algorithm used for optimization.
+ * - "l-bfgs" denotes Limited-memory BFGS which is a limited-memory quasi-Newton
+ * optimization method.
+ * - "normal" denotes using Normal Equation as an analytical solution to the linear regression
+ * problem. This solver is limited to `LinearRegression.MAX_FEATURES_FOR_NORMAL_SOLVER`.
+ * - "auto" (default) means that the solver algorithm is selected automatically.
+ * The Normal Equations solver will be used when possible, but this will automatically fall
+ * back to iterative optimization methods when needed.
+ *
+ * spark
+ */
+
+sealed abstract class Solver private[ml] (val sparkValue: String)
-sealed abstract class Solver private[ml](val sparkValue: String)
object Solver {
- case object LBFGS extends Solver("l-bfgs")
- case object Auto extends Solver("auto")
- case object Normal extends Solver("normal")
+ case object LBFGS extends Solver("l-bfgs")
+ case object Auto extends Solver("auto")
+ case object Normal extends Solver("normal")
}
-
diff --git a/ml/src/main/scala/frameless/ml/params/trees/FeatureSubsetStrategy.scala b/ml/src/main/scala/frameless/ml/params/trees/FeatureSubsetStrategy.scala
index f2167f983..67f2ddc8d 100644
--- a/ml/src/main/scala/frameless/ml/params/trees/FeatureSubsetStrategy.scala
+++ b/ml/src/main/scala/frameless/ml/params/trees/FeatureSubsetStrategy.scala
@@ -2,32 +2,34 @@ package frameless
package ml
package params
package trees
+
/**
- * The number of features to consider for splits at each tree node.
- * Supported options:
- * - Auto: Choose automatically for task:
- * If numTrees == 1, set to All
- * If numTrees > 1 (forest), set to Sqrt for classification and
- * to OneThird for regression.
- * - All: use all features
- * - OneThird: use 1/3 of the features
- * - Sqrt: use sqrt(number of features)
- * - Log2: use log2(number of features)
- * - Ratio: use (ratio * number of features) features
- * - NumberOfFeatures: use numberOfFeatures features.
- * (default = Auto)
- *
- * These various settings are based on the following references:
- * - log2: tested in Breiman (2001)
- * - sqrt: recommended by Breiman manual for random forests
- * - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
- * package.
- *
- * @see Breiman (2001)
- * @see
- * Breiman manual for random forests
- */
-sealed abstract class FeatureSubsetStrategy private[ml](val sparkValue: String)
+ * The number of features to consider for splits at each tree node.
+ * Supported options:
+ * - Auto: Choose automatically for task:
+ * If numTrees == 1, set to All
+ * If numTrees > 1 (forest), set to Sqrt for classification and
+ * to OneThird for regression.
+ * - All: use all features
+ * - OneThird: use 1/3 of the features
+ * - Sqrt: use sqrt(number of features)
+ * - Log2: use log2(number of features)
+ * - Ratio: use (ratio * number of features) features
+ * - NumberOfFeatures: use numberOfFeatures features.
+ * (default = Auto)
+ *
+ * These various settings are based on the following references:
+ * - log2: tested in Breiman (2001)
+ * - sqrt: recommended by Breiman manual for random forests
+ * - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
+ * package.
+ *
+ * @see Breiman (2001)
+ * @see
+ * Breiman manual for random forests
+ */
+sealed abstract class FeatureSubsetStrategy private[ml] (val sparkValue: String)
+
object FeatureSubsetStrategy {
case object Auto extends FeatureSubsetStrategy("auto")
case object All extends FeatureSubsetStrategy("all")
@@ -35,5 +37,7 @@ object FeatureSubsetStrategy {
case object Sqrt extends FeatureSubsetStrategy("sqrt")
case object Log2 extends FeatureSubsetStrategy("log2")
case class Ratio(value: Double) extends FeatureSubsetStrategy(value.toString)
- case class NumberOfFeatures(value: Int) extends FeatureSubsetStrategy(value.toString)
-}
\ No newline at end of file
+
+ case class NumberOfFeatures(value: Int)
+ extends FeatureSubsetStrategy(value.toString)
+}
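Note: to make the mapping in the Scaladoc above concrete, each variant simply carries the string Spark expects; the chosen values are illustrative:

    import frameless.ml.params.trees.FeatureSubsetStrategy

    val strategies: Seq[FeatureSubsetStrategy] = Seq(
      FeatureSubsetStrategy.Auto,               // "auto"
      FeatureSubsetStrategy.Sqrt,               // "sqrt"
      FeatureSubsetStrategy.Ratio(0.5),         // "0.5"
      FeatureSubsetStrategy.NumberOfFeatures(3) // "3"
    )

    // the strings handed to the underlying Spark estimator via setFeatureSubsetStrategy
    strategies.map(_.sparkValue)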
diff --git a/ml/src/main/scala/frameless/ml/regression/TypedLinearRegression.scala b/ml/src/main/scala/frameless/ml/regression/TypedLinearRegression.scala
index 3b3208623..a1eeb169c 100644
--- a/ml/src/main/scala/frameless/ml/regression/TypedLinearRegression.scala
+++ b/ml/src/main/scala/frameless/ml/regression/TypedLinearRegression.scala
@@ -3,38 +3,66 @@ package ml
package regression
import frameless.ml.internals.LinearInputsChecker
-import frameless.ml.params.linears.{LossStrategy, Solver}
-import frameless.ml.{AppendTransformer, TypedEstimator}
-import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
+import frameless.ml.params.linears.{ LossStrategy, Solver }
+import frameless.ml.{ AppendTransformer, TypedEstimator }
+import org.apache.spark.ml.regression.{
+ LinearRegression,
+ LinearRegressionModel
+}
/**
- * Linear Regression linear approach to modelling the relationship
- * between a scalar response (or dependent variable) and one or more explanatory variables
- */
-final class TypedLinearRegression [Inputs] private[ml](
- lr: LinearRegression,
- labelCol: String,
- featuresCol: String,
- weightCol: Option[String]
-) extends TypedEstimator[Inputs, TypedLinearRegression.Outputs, LinearRegressionModel] {
-
- val estimatorWithoutWeight : LinearRegression = lr
+ * Linear Regression: a linear approach to modelling the relationship
+ * between a scalar response (or dependent variable) and one or more explanatory variables.
+ */
+final class TypedLinearRegression[Inputs] private[ml] (
+ lr: LinearRegression,
+ labelCol: String,
+ featuresCol: String,
+ weightCol: Option[String])
+ extends TypedEstimator[
+ Inputs,
+ TypedLinearRegression.Outputs,
+ LinearRegressionModel
+ ] {
+
+ val estimatorWithoutWeight: LinearRegression = lr
.setLabelCol(labelCol)
.setFeaturesCol(featuresCol)
.setPredictionCol(AppendTransformer.tempColumnName)
- val estimator = if (weightCol.isDefined) estimatorWithoutWeight.setWeightCol(weightCol.get) else estimatorWithoutWeight
+ val estimator =
+ if (weightCol.isDefined) estimatorWithoutWeight.setWeightCol(weightCol.get)
+ else estimatorWithoutWeight
+
+ def setRegParam(value: Double): TypedLinearRegression[Inputs] =
+ copy(lr.setRegParam(value))
+
+ def setFitIntercept(value: Boolean): TypedLinearRegression[Inputs] =
+ copy(lr.setFitIntercept(value))
+
+ def setStandardization(value: Boolean): TypedLinearRegression[Inputs] =
+ copy(lr.setStandardization(value))
+
+ def setElasticNetParam(value: Double): TypedLinearRegression[Inputs] =
+ copy(lr.setElasticNetParam(value))
+
+ def setMaxIter(value: Int): TypedLinearRegression[Inputs] =
+ copy(lr.setMaxIter(value))
- def setRegParam(value: Double): TypedLinearRegression[Inputs] = copy(lr.setRegParam(value))
- def setFitIntercept(value: Boolean): TypedLinearRegression[Inputs] = copy(lr.setFitIntercept(value))
- def setStandardization(value: Boolean): TypedLinearRegression[Inputs] = copy(lr.setStandardization(value))
- def setElasticNetParam(value: Double): TypedLinearRegression[Inputs] = copy(lr.setElasticNetParam(value))
- def setMaxIter(value: Int): TypedLinearRegression[Inputs] = copy(lr.setMaxIter(value))
- def setTol(value: Double): TypedLinearRegression[Inputs] = copy(lr.setTol(value))
- def setSolver(value: Solver): TypedLinearRegression[Inputs] = copy(lr.setSolver(value.sparkValue))
- def setAggregationDepth(value: Int): TypedLinearRegression[Inputs] = copy(lr.setAggregationDepth(value))
- def setLoss(value: LossStrategy): TypedLinearRegression[Inputs] = copy(lr.setLoss(value.sparkValue))
- def setEpsilon(value: Double): TypedLinearRegression[Inputs] = copy(lr.setEpsilon(value))
+ def setTol(value: Double): TypedLinearRegression[Inputs] =
+ copy(lr.setTol(value))
+
+ def setSolver(value: Solver): TypedLinearRegression[Inputs] =
+ copy(lr.setSolver(value.sparkValue))
+
+ def setAggregationDepth(value: Int): TypedLinearRegression[Inputs] =
+ copy(lr.setAggregationDepth(value))
+
+ def setLoss(value: LossStrategy): TypedLinearRegression[Inputs] =
+ copy(lr.setLoss(value.sparkValue))
+
+ def setEpsilon(value: Double): TypedLinearRegression[Inputs] =
+ copy(lr.setEpsilon(value))
private def copy(newLr: LinearRegression): TypedLinearRegression[Inputs] =
new TypedLinearRegression[Inputs](newLr, labelCol, featuresCol, weightCol)
@@ -45,8 +73,15 @@ object TypedLinearRegression {
case class Outputs(prediction: Double)
case class Weight(weight: Double)
-
- def apply[Inputs](implicit inputsChecker: LinearInputsChecker[Inputs]): TypedLinearRegression[Inputs] = {
- new TypedLinearRegression(new LinearRegression(), inputsChecker.labelCol, inputsChecker.featuresCol, inputsChecker.weightCol)
+ def apply[Inputs](
+ implicit
+ inputsChecker: LinearInputsChecker[Inputs]
+ ): TypedLinearRegression[Inputs] = {
+ new TypedLinearRegression(
+ new LinearRegression(),
+ inputsChecker.labelCol,
+ inputsChecker.featuresCol,
+ inputsChecker.weightCol
+ )
}
-}
\ No newline at end of file
+}
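Note: a hedged usage sketch of the regressor above with the two-field input shape (Double label plus Vector features); adding a Float field would supply the optional weight column. `HousePrice` and all values are illustrative, and an implicit SparkSession is assumed:

    import frameless.TypedDataset
    import frameless.ml._
    import frameless.ml.regression.TypedLinearRegression
    import frameless.ml.params.linears.{ LossStrategy, Solver }
    import org.apache.spark.ml.linalg.{ Vector, Vectors }

    case class HousePrice(price: Double, features: Vector)

    val lr = TypedLinearRegression[HousePrice]
      .setMaxIter(10)
      .setSolver(Solver.Auto)
      .setLoss(LossStrategy.SquaredError)

    val ds = TypedDataset.create(Seq(HousePrice(300000d, Vectors.dense(85.0, 3.0))))
    val model = lr.fit(ds).run()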
diff --git a/ml/src/main/scala/frameless/ml/regression/TypedRandomForestRegressor.scala b/ml/src/main/scala/frameless/ml/regression/TypedRandomForestRegressor.scala
index 69c1ad68c..89586671b 100644
--- a/ml/src/main/scala/frameless/ml/regression/TypedRandomForestRegressor.scala
+++ b/ml/src/main/scala/frameless/ml/regression/TypedRandomForestRegressor.scala
@@ -4,44 +4,74 @@ package regression
import frameless.ml.internals.TreesInputsChecker
import frameless.ml.params.trees.FeatureSubsetStrategy
-import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor}
+import org.apache.spark.ml.regression.{
+ RandomForestRegressionModel,
+ RandomForestRegressor
+}
/**
- * Random Forest
- * learning algorithm for regression.
- * It supports both continuous and categorical features.
- */
-final class TypedRandomForestRegressor[Inputs] private[ml](
- rf: RandomForestRegressor,
- labelCol: String,
- featuresCol: String
-) extends TypedEstimator[Inputs, TypedRandomForestRegressor.Outputs, RandomForestRegressionModel] {
+ * Random Forest
+ * learning algorithm for regression.
+ * It supports both continuous and categorical features.
+ */
+final class TypedRandomForestRegressor[Inputs] private[ml] (
+ rf: RandomForestRegressor,
+ labelCol: String,
+ featuresCol: String)
+ extends TypedEstimator[
+ Inputs,
+ TypedRandomForestRegressor.Outputs,
+ RandomForestRegressionModel
+ ] {
val estimator: RandomForestRegressor =
- rf
- .setLabelCol(labelCol)
+ rf.setLabelCol(labelCol)
.setFeaturesCol(featuresCol)
.setPredictionCol(AppendTransformer.tempColumnName)
- def setNumTrees(value: Int): TypedRandomForestRegressor[Inputs] = copy(rf.setNumTrees(value))
- def setMaxDepth(value: Int): TypedRandomForestRegressor[Inputs] = copy(rf.setMaxDepth(value))
- def setMinInfoGain(value: Double): TypedRandomForestRegressor[Inputs] = copy(rf.setMinInfoGain(value))
- def setMinInstancesPerNode(value: Int): TypedRandomForestRegressor[Inputs] = copy(rf.setMinInstancesPerNode(value))
- def setMaxMemoryInMB(value: Int): TypedRandomForestRegressor[Inputs] = copy(rf.setMaxMemoryInMB(value))
- def setSubsamplingRate(value: Double): TypedRandomForestRegressor[Inputs] = copy(rf.setSubsamplingRate(value))
- def setFeatureSubsetStrategy(value: FeatureSubsetStrategy): TypedRandomForestRegressor[Inputs] =
+ def setNumTrees(value: Int): TypedRandomForestRegressor[Inputs] =
+ copy(rf.setNumTrees(value))
+
+ def setMaxDepth(value: Int): TypedRandomForestRegressor[Inputs] =
+ copy(rf.setMaxDepth(value))
+
+ def setMinInfoGain(value: Double): TypedRandomForestRegressor[Inputs] =
+ copy(rf.setMinInfoGain(value))
+
+ def setMinInstancesPerNode(value: Int): TypedRandomForestRegressor[Inputs] =
+ copy(rf.setMinInstancesPerNode(value))
+
+ def setMaxMemoryInMB(value: Int): TypedRandomForestRegressor[Inputs] =
+ copy(rf.setMaxMemoryInMB(value))
+
+ def setSubsamplingRate(value: Double): TypedRandomForestRegressor[Inputs] =
+ copy(rf.setSubsamplingRate(value))
+
+ def setFeatureSubsetStrategy(
+ value: FeatureSubsetStrategy
+ ): TypedRandomForestRegressor[Inputs] =
copy(rf.setFeatureSubsetStrategy(value.sparkValue))
- def setMaxBins(value: Int): TypedRandomForestRegressor[Inputs] = copy(rf.setMaxBins(value))
- private def copy(newRf: RandomForestRegressor): TypedRandomForestRegressor[Inputs] =
+ def setMaxBins(value: Int): TypedRandomForestRegressor[Inputs] =
+ copy(rf.setMaxBins(value))
+
+ private def copy(
+ newRf: RandomForestRegressor
+ ): TypedRandomForestRegressor[Inputs] =
new TypedRandomForestRegressor[Inputs](newRf, labelCol, featuresCol)
}
object TypedRandomForestRegressor {
case class Outputs(prediction: Double)
- def apply[Inputs](implicit inputsChecker: TreesInputsChecker[Inputs])
- : TypedRandomForestRegressor[Inputs] = {
- new TypedRandomForestRegressor(new RandomForestRegressor(), inputsChecker.labelCol, inputsChecker.featuresCol)
+ def apply[Inputs](
+ implicit
+ inputsChecker: TreesInputsChecker[Inputs]
+ ): TypedRandomForestRegressor[Inputs] = {
+ new TypedRandomForestRegressor(
+ new RandomForestRegressor(),
+ inputsChecker.labelCol,
+ inputsChecker.featuresCol
+ )
}
-}
\ No newline at end of file
+}
diff --git a/ml/src/main/scala/org/apache/spark/ml/FramelessInternals.scala b/ml/src/main/scala/org/apache/spark/ml/FramelessInternals.scala
index bec43cd11..6f697ea28 100644
--- a/ml/src/main/scala/org/apache/spark/ml/FramelessInternals.scala
+++ b/ml/src/main/scala/org/apache/spark/ml/FramelessInternals.scala
@@ -1,6 +1,6 @@
package org.apache.spark.ml
-import org.apache.spark.ml.linalg.{MatrixUDT, VectorUDT}
+import org.apache.spark.ml.linalg.{ MatrixUDT, VectorUDT }
object FramelessInternals {
diff --git a/ml/src/test/scala/frameless/ml/FramelessMlSuite.scala b/ml/src/test/scala/frameless/ml/FramelessMlSuite.scala
index de8fcab56..60104fbd2 100644
--- a/ml/src/test/scala/frameless/ml/FramelessMlSuite.scala
+++ b/ml/src/test/scala/frameless/ml/FramelessMlSuite.scala
@@ -6,7 +6,12 @@ import org.scalatest.BeforeAndAfterAll
import org.scalatestplus.scalacheck.Checkers
import org.scalatest.funsuite.AnyFunSuite
-class FramelessMlSuite extends AnyFunSuite with Checkers with BeforeAndAfterAll with SparkTesting {
+class FramelessMlSuite
+ extends AnyFunSuite
+ with Checkers
+ with BeforeAndAfterAll
+ with SparkTesting {
+
// Limit size of generated collections and number of checks because Travis
implicit override val generatorDrivenConfig =
PropertyCheckConfiguration(sizeRange = PosZInt(10), minSize = PosZInt(10))
diff --git a/ml/src/test/scala/frameless/ml/Generators.scala b/ml/src/test/scala/frameless/ml/Generators.scala
index f7dde986c..48c0e4bc0 100644
--- a/ml/src/test/scala/frameless/ml/Generators.scala
+++ b/ml/src/test/scala/frameless/ml/Generators.scala
@@ -1,15 +1,18 @@
package frameless
package ml
-import frameless.ml.params.linears.{LossStrategy, Solver}
+import frameless.ml.params.linears.{ LossStrategy, Solver }
import frameless.ml.params.trees.FeatureSubsetStrategy
-import org.apache.spark.ml.linalg.{Matrices, Matrix, Vector, Vectors}
-import org.scalacheck.{Arbitrary, Gen}
+import org.apache.spark.ml.linalg.{ Matrices, Matrix, Vector, Vectors }
+import org.scalacheck.{ Arbitrary, Gen }
object Generators {
implicit val arbVector: Arbitrary[Vector] = Arbitrary {
- val genDenseVector = Gen.listOf(arbDouble.arbitrary).suchThat(_.nonEmpty).map(doubles => Vectors.dense(doubles.toArray))
+ val genDenseVector = Gen
+ .listOf(arbDouble.arbitrary)
+ .suchThat(_.nonEmpty)
+ .map(doubles => Vectors.dense(doubles.toArray))
val genSparseVector = genDenseVector.map(_.toSparse)
Gen.oneOf(genDenseVector, genSparseVector)
@@ -21,29 +24,34 @@ object Generators {
nbRows <- Gen.choose(0, size)
nbCols <- Gen.choose(1, size)
matrix <- {
- Gen.listOfN(nbRows * nbCols, arbDouble.arbitrary)
+ Gen
+ .listOfN(nbRows * nbCols, arbDouble.arbitrary)
.map(values => Matrices.dense(nbRows, nbCols, values.toArray))
}
} yield matrix
}
}
- implicit val arbTreesFeaturesSubsetStrategy: Arbitrary[FeatureSubsetStrategy] = Arbitrary {
- val genRatio = Gen.choose(0D, 1D).suchThat(_ > 0D).map(FeatureSubsetStrategy.Ratio)
- val genNumberOfFeatures = Gen.choose(1, Int.MaxValue).map(FeatureSubsetStrategy.NumberOfFeatures)
-
- Gen.oneOf(Gen.const(FeatureSubsetStrategy.All),
- Gen.const(FeatureSubsetStrategy.All),
- Gen.const(FeatureSubsetStrategy.Log2),
- Gen.const(FeatureSubsetStrategy.OneThird),
- Gen.const(FeatureSubsetStrategy.Sqrt),
- genRatio,
- genNumberOfFeatures
- )
- }
+ implicit val arbTreesFeaturesSubsetStrategy: Arbitrary[FeatureSubsetStrategy] =
+ Arbitrary {
+ val genRatio =
+ Gen.choose(0D, 1D).suchThat(_ > 0D).map(FeatureSubsetStrategy.Ratio)
+ val genNumberOfFeatures =
+ Gen.choose(1, Int.MaxValue).map(FeatureSubsetStrategy.NumberOfFeatures)
+
+ Gen.oneOf(
+ Gen.const(FeatureSubsetStrategy.All),
+ Gen.const(FeatureSubsetStrategy.All),
+ Gen.const(FeatureSubsetStrategy.Log2),
+ Gen.const(FeatureSubsetStrategy.OneThird),
+ Gen.const(FeatureSubsetStrategy.Sqrt),
+ genRatio,
+ genNumberOfFeatures
+ )
+ }
implicit val arbLossStrategy: Arbitrary[LossStrategy] = Arbitrary {
- Gen.const(LossStrategy.SquaredError)
+ Gen.const(LossStrategy.SquaredError)
}
implicit val arbSolver: Arbitrary[Solver] = Arbitrary {
diff --git a/ml/src/test/scala/frameless/ml/TypedEncoderInstancesTests.scala b/ml/src/test/scala/frameless/ml/TypedEncoderInstancesTests.scala
index 0f7e37439..a54e09e42 100644
--- a/ml/src/test/scala/frameless/ml/TypedEncoderInstancesTests.scala
+++ b/ml/src/test/scala/frameless/ml/TypedEncoderInstancesTests.scala
@@ -23,12 +23,15 @@ class TypedEncoderInstancesTests extends FramelessMlSuite {
check(prop)
}
- test("Vector is encoded as VectorUDT and thus can be run in a Spark ML model") {
+ test(
+ "Vector is encoded as VectorUDT and thus can be run in a Spark ML model"
+ ) {
case class Input(features: Vector, label: Double)
val prop = forAll { trainingData: Matrix =>
(trainingData.numRows >= 1) ==> {
- val inputs = trainingData.rowIter.toVector.map(vector => Input(vector, 0D))
+ val inputs =
+ trainingData.rowIter.toVector.map(vector => Input(vector, 0D))
val inputsDS = TypedDataset.create(inputs)
val model = new DecisionTreeRegressor()
@@ -39,7 +42,8 @@ class TypedEncoderInstancesTests extends FramelessMlSuite {
val randomInput = inputs(Random.nextInt(inputs.length))
val randomInputDS = TypedDataset.create(Seq(randomInput))
- val prediction = trainedModel.transform(randomInputDS.dataset)
+ val prediction = trainedModel
+ .transform(randomInputDS.dataset)
.select("prediction")
.head()
.getAs[Double](0)
diff --git a/ml/src/test/scala/frameless/ml/classification/ClassificationIntegrationTests.scala b/ml/src/test/scala/frameless/ml/classification/ClassificationIntegrationTests.scala
index d98c3b2bf..9fbe2bebc 100644
--- a/ml/src/test/scala/frameless/ml/classification/ClassificationIntegrationTests.scala
+++ b/ml/src/test/scala/frameless/ml/classification/ClassificationIntegrationTests.scala
@@ -2,7 +2,11 @@ package frameless
package ml
package classification
-import frameless.ml.feature.{TypedIndexToString, TypedStringIndexer, TypedVectorAssembler}
+import frameless.ml.feature.{
+ TypedIndexToString,
+ TypedStringIndexer,
+ TypedVectorAssembler
+}
import org.apache.spark.ml.linalg.Vector
import org.scalatest.matchers.must.Matchers
@@ -18,15 +22,26 @@ class ClassificationIntegrationTests extends FramelessMlSuite with Matchers {
case class Features(field1: Double, field2: Int)
val vectorAssembler = TypedVectorAssembler[Features]
- case class DataWithFeatures(field1: Double, field2: Int, field3: String, features: Vector)
- val dataWithFeatures = vectorAssembler.transform(trainingDataDs).as[DataWithFeatures]()
+ case class DataWithFeatures(
+ field1: Double,
+ field2: Int,
+ field3: String,
+ features: Vector)
+ val dataWithFeatures =
+ vectorAssembler.transform(trainingDataDs).as[DataWithFeatures]()
case class StringIndexerInput(field3: String)
val indexer = TypedStringIndexer[StringIndexerInput]
val indexerModel = indexer.fit(dataWithFeatures).run()
- case class IndexedDataWithFeatures(field1: Double, field2: Int, field3: String, features: Vector, indexedField3: Double)
- val indexedData = indexerModel.transform(dataWithFeatures).as[IndexedDataWithFeatures]()
+ case class IndexedDataWithFeatures(
+ field1: Double,
+ field2: Int,
+ field3: String,
+ features: Vector,
+ indexedField3: Double)
+ val indexedData =
+ indexerModel.transform(dataWithFeatures).as[IndexedDataWithFeatures]()
case class RFInputs(indexedField3: Double, features: Vector)
val rf = TypedRandomForestClassifier[RFInputs]
@@ -35,38 +50,47 @@ class ClassificationIntegrationTests extends FramelessMlSuite with Matchers {
// Prediction
- val testData = TypedDataset.create(Seq(
- Data(0D, 10, "foo")
- ))
- val testDataWithFeatures = vectorAssembler.transform(testData).as[DataWithFeatures]()
- val indexedTestData = indexerModel.transform(testDataWithFeatures).as[IndexedDataWithFeatures]()
+ val testData = TypedDataset.create(
+ Seq(
+ Data(0D, 10, "foo")
+ )
+ )
+ val testDataWithFeatures =
+ vectorAssembler.transform(testData).as[DataWithFeatures]()
+ val indexedTestData =
+ indexerModel.transform(testDataWithFeatures).as[IndexedDataWithFeatures]()
case class PredictionInputs(features: Vector, indexedField3: Double)
val testInput = indexedTestData.project[PredictionInputs]
case class PredictionResultIndexed(
- features: Vector,
- indexedField3: Double,
- rawPrediction: Vector,
- probability: Vector,
- predictedField3Indexed: Double
- )
+ features: Vector,
+ indexedField3: Double,
+ rawPrediction: Vector,
+ probability: Vector,
+ predictedField3Indexed: Double)
val predictionDs = model.transform(testInput).as[PredictionResultIndexed]()
case class IndexToStringInput(predictedField3Indexed: Double)
- val indexToString = TypedIndexToString[IndexToStringInput](indexerModel.transformer.labelsArray.flatten)
-
- case class PredictionResult(
- features: Vector,
- indexedField3: Double,
- rawPrediction: Vector,
- probability: Vector,
- predictedField3Indexed: Double,
- predictedField3: String
+ val indexToString = TypedIndexToString[IndexToStringInput](
+ indexerModel.transformer.labelsArray.flatten
)
- val stringPredictionDs = indexToString.transform(predictionDs).as[PredictionResult]()
- val prediction = stringPredictionDs.select(stringPredictionDs.col('predictedField3)).collect().run().toList
+ case class PredictionResult(
+ features: Vector,
+ indexedField3: Double,
+ rawPrediction: Vector,
+ probability: Vector,
+ predictedField3Indexed: Double,
+ predictedField3: String)
+ val stringPredictionDs =
+ indexToString.transform(predictionDs).as[PredictionResult]()
+
+ val prediction = stringPredictionDs
+ .select(stringPredictionDs.col('predictedField3))
+ .collect()
+ .run()
+ .toList
prediction mustEqual List("foo")
}
diff --git a/ml/src/test/scala/frameless/ml/classification/TypedRandomForestClassifierTests.scala b/ml/src/test/scala/frameless/ml/classification/TypedRandomForestClassifierTests.scala
index ab03f1aad..f5571d17a 100644
--- a/ml/src/test/scala/frameless/ml/classification/TypedRandomForestClassifierTests.scala
+++ b/ml/src/test/scala/frameless/ml/classification/TypedRandomForestClassifierTests.scala
@@ -5,15 +5,21 @@ package classification
import shapeless.test.illTyped
import org.apache.spark.ml.linalg._
import frameless.ml.params.trees.FeatureSubsetStrategy
-import org.scalacheck.{Arbitrary, Gen}
+import org.scalacheck.{ Arbitrary, Gen }
import org.scalacheck.Prop._
import org.scalatest.matchers.must.Matchers
class TypedRandomForestClassifierTests extends FramelessMlSuite with Matchers {
+
implicit val arbDouble: Arbitrary[Double] =
- Arbitrary(Gen.choose(1, 99).map(_.toDouble)) // num classes must be between 0 and 100 for the test
+ Arbitrary(
+ Gen.choose(1, 99).map(_.toDouble)
+ ) // num classes must be between 0 and 100 for the test
+
implicit val arbVectorNonEmpty: Arbitrary[Vector] =
- Arbitrary(Generators.arbVector.arbitrary suchThat (_.size > 0)) // vector must not be empty for RandomForestClassifier
+ Arbitrary(
+ Generators.arbVector.arbitrary suchThat (_.size > 0)
+ ) // vector must not be empty for RandomForestClassifier
import Generators.arbTreesFeaturesSubsetStrategy
test("fit() returns a correct TypedTransformer") {
@@ -21,7 +27,8 @@ class TypedRandomForestClassifierTests extends FramelessMlSuite with Matchers {
val rf = TypedRandomForestClassifier[X2[Double, Vector]]
val ds = TypedDataset.create(Seq(x2))
val model = rf.fit(ds).run()
- val pDs = model.transform(ds).as[X5[Double, Vector, Vector, Vector, Double]]()
+ val pDs =
+ model.transform(ds).as[X5[Double, Vector, Vector, Vector, Double]]()
pDs.select(pDs.col('a), pDs.col('b)).collect().run() == Seq(x2.a -> x2.b)
}
@@ -30,19 +37,26 @@ class TypedRandomForestClassifierTests extends FramelessMlSuite with Matchers {
val rf = TypedRandomForestClassifier[X2[Vector, Double]]
val ds = TypedDataset.create(Seq(x2))
val model = rf.fit(ds).run()
- val pDs = model.transform(ds).as[X5[Vector, Double, Vector, Vector, Double]]()
+ val pDs =
+ model.transform(ds).as[X5[Vector, Double, Vector, Vector, Double]]()
pDs.select(pDs.col('a), pDs.col('b)).collect().run() == Seq(x2.a -> x2.b)
}
- def prop3[A: TypedEncoder: Arbitrary] = forAll { x3: X3[Vector, Double, A] =>
- val rf = TypedRandomForestClassifier[X2[Vector, Double]]
- val ds = TypedDataset.create(Seq(x3))
- val model = rf.fit(ds).run()
- val pDs = model.transform(ds).as[X6[Vector, Double, A, Vector, Vector, Double]]()
+ def prop3[A: TypedEncoder: Arbitrary] =
+ forAll { x3: X3[Vector, Double, A] =>
+ val rf = TypedRandomForestClassifier[X2[Vector, Double]]
+ val ds = TypedDataset.create(Seq(x3))
+ val model = rf.fit(ds).run()
+ val pDs = model
+ .transform(ds)
+ .as[X6[Vector, Double, A, Vector, Vector, Double]]()
- pDs.select(pDs.col('a), pDs.col('b), pDs.col('c)).collect().run() == Seq((x3.a, x3.b, x3.c))
- }
+ pDs
+ .select(pDs.col('a), pDs.col('b), pDs.col('c))
+ .collect()
+ .run() == Seq((x3.a, x3.b, x3.c))
+ }
check(prop)
check(prop2)
@@ -66,13 +80,13 @@ class TypedRandomForestClassifierTests extends FramelessMlSuite with Matchers {
val model = rf.fit(ds).run()
model.transformer.getNumTrees == 10 &&
- model.transformer.getMaxBins == 100 &&
- model.transformer.getFeatureSubsetStrategy == featureSubsetStrategy.sparkValue &&
- model.transformer.getMaxDepth == 10 &&
- model.transformer.getMaxMemoryInMB == 100 &&
- model.transformer.getMinInfoGain == 0.1D &&
- model.transformer.getMinInstancesPerNode == 2 &&
- model.transformer.getSubsamplingRate == 0.9D
+ model.transformer.getMaxBins == 100 &&
+ model.transformer.getFeatureSubsetStrategy == featureSubsetStrategy.sparkValue &&
+ model.transformer.getMaxDepth == 10 &&
+ model.transformer.getMaxMemoryInMB == 100 &&
+ model.transformer.getMinInfoGain == 0.1D &&
+ model.transformer.getMinInstancesPerNode == 2 &&
+ model.transformer.getSubsamplingRate == 0.9D
}
check(prop)
@@ -86,4 +100,4 @@ class TypedRandomForestClassifierTests extends FramelessMlSuite with Matchers {
illTyped("TypedRandomForestClassifier.create[X2[Vector, String]]()")
}
-}
\ No newline at end of file
+}
diff --git a/ml/src/test/scala/frameless/ml/clustering/BisectingKMeansTests.scala b/ml/src/test/scala/frameless/ml/clustering/BisectingKMeansTests.scala
index 976df39b2..c19c42ce3 100644
--- a/ml/src/test/scala/frameless/ml/clustering/BisectingKMeansTests.scala
+++ b/ml/src/test/scala/frameless/ml/clustering/BisectingKMeansTests.scala
@@ -2,7 +2,7 @@ package frameless
package ml
package clustering
-import frameless.{TypedDataset, TypedEncoder, X1, X2, X3}
+import frameless.{ TypedDataset, TypedEncoder, X1, X2, X3 }
import frameless.ml.classification.TypedBisectingKMeans
import org.scalacheck.Arbitrary
import org.apache.spark.ml.linalg._
@@ -11,6 +11,7 @@ import frameless.ml._
import org.scalatest.matchers.must.Matchers
class BisectingKMeansTests extends FramelessMlSuite with Matchers {
+
implicit val arbVector: Arbitrary[Vector] =
Arbitrary(Generators.arbVector.arbitrary)
@@ -24,7 +25,7 @@ class BisectingKMeansTests extends FramelessMlSuite with Matchers {
pDs.select(pDs.col('a)).collect().run().toList == Seq(x1.a)
}
- def prop3[A: TypedEncoder : Arbitrary] = forAll { x2: X2[Vector, A] =>
+ def prop3[A: TypedEncoder: Arbitrary] = forAll { x2: X2[Vector, A] =>
val km = TypedBisectingKMeans[X1[Vector]]()
val ds = TypedDataset.create(Seq(x2))
val model = km.fit(ds).run()
@@ -44,12 +45,12 @@ class BisectingKMeansTests extends FramelessMlSuite with Matchers {
.setMinDivisibleClusterSize(1)
.setSeed(123332)
- val ds = TypedDataset.create(Seq(X2(Vectors.dense(Array(0D)),0)))
+ val ds = TypedDataset.create(Seq(X2(Vectors.dense(Array(0D)), 0)))
val model = rf.fit(ds).run()
- model.transformer.getK == 10 &&
- model.transformer.getMaxIter == 10 &&
- model.transformer.getMinDivisibleClusterSize == 1 &&
- model.transformer.getSeed == 123332
+ model.transformer.getK == 10 &&
+ model.transformer.getMaxIter == 10 &&
+ model.transformer.getMinDivisibleClusterSize == 1 &&
+ model.transformer.getSeed == 123332
}
}
diff --git a/ml/src/test/scala/frameless/ml/clustering/ClusteringIntegrationTests.scala b/ml/src/test/scala/frameless/ml/clustering/ClusteringIntegrationTests.scala
index 398a0963d..a59cc03e3 100644
--- a/ml/src/test/scala/frameless/ml/clustering/ClusteringIntegrationTests.scala
+++ b/ml/src/test/scala/frameless/ml/clustering/ClusteringIntegrationTests.scala
@@ -3,7 +3,7 @@ package ml
package clustering
import frameless.ml.FramelessMlSuite
-import frameless.ml.classification.{TypedBisectingKMeans, TypedKMeans}
+import frameless.ml.classification.{ TypedBisectingKMeans, TypedKMeans }
import org.apache.spark.ml.linalg.Vector
import frameless._
import frameless.ml._
@@ -14,11 +14,13 @@ class ClusteringIntegrationTests extends FramelessMlSuite with Matchers {
test("predict field2 from field1 using a K-means clustering") {
// Training
- val trainingDataDs = TypedDataset.create(Seq.fill(5)(X2(10D, 0)) :+ X2(100D,0))
+ val trainingDataDs =
+ TypedDataset.create(Seq.fill(5)(X2(10D, 0)) :+ X2(100D, 0))
val vectorAssembler = TypedVectorAssembler[X1[Double]]
- val dataWithFeatures = vectorAssembler.transform(trainingDataDs).as[X3[Double,Int,Vector]]()
+ val dataWithFeatures =
+ vectorAssembler.transform(trainingDataDs).as[X3[Double, Int, Vector]]()
case class Input(c: Vector)
val km = TypedKMeans[Input].setK(2)
@@ -32,22 +34,27 @@ class ClusteringIntegrationTests extends FramelessMlSuite with Matchers {
)
val testData = TypedDataset.create(testSeq)
- val testDataWithFeatures = vectorAssembler.transform(testData).as[X3[Double,Int,Vector]]()
+ val testDataWithFeatures =
+ vectorAssembler.transform(testData).as[X3[Double, Int, Vector]]()
- val predictionDs = model.transform(testDataWithFeatures).as[X4[Double,Int,Vector,Int]]()
+ val predictionDs =
+ model.transform(testDataWithFeatures).as[X4[Double, Int, Vector, Int]]()
- val prediction = predictionDs.select(predictionDs.col[Int]('d)).collect().run().toList
+ val prediction =
+ predictionDs.select(predictionDs.col[Int]('d)).collect().run().toList
prediction mustEqual testSeq.map(_.b)
}
test("predict field2 from field1 using a bisecting K-means clustering") {
// Training
- val trainingDataDs = TypedDataset.create(Seq.fill(5)(X2(10D, 0)) :+ X2(100D,0))
+ val trainingDataDs =
+ TypedDataset.create(Seq.fill(5)(X2(10D, 0)) :+ X2(100D, 0))
val vectorAssembler = TypedVectorAssembler[X1[Double]]
- val dataWithFeatures = vectorAssembler.transform(trainingDataDs).as[X3[Double, Int, Vector]]()
+ val dataWithFeatures =
+ vectorAssembler.transform(trainingDataDs).as[X3[Double, Int, Vector]]()
case class Inputs(c: Vector)
val bkm = TypedBisectingKMeans[Inputs]().setK(2)
@@ -61,11 +68,14 @@ class ClusteringIntegrationTests extends FramelessMlSuite with Matchers {
)
val testData = TypedDataset.create(testSeq)
- val testDataWithFeatures = vectorAssembler.transform(testData).as[X3[Double, Int, Vector]]()
+ val testDataWithFeatures =
+ vectorAssembler.transform(testData).as[X3[Double, Int, Vector]]()
- val predictionDs = model.transform(testDataWithFeatures).as[X4[Double,Int,Vector,Int]]()
+ val predictionDs =
+ model.transform(testDataWithFeatures).as[X4[Double, Int, Vector, Int]]()
- val prediction = predictionDs.select(predictionDs.col[Int]('d)).collect().run().toList
+ val prediction =
+ predictionDs.select(predictionDs.col[Int]('d)).collect().run().toList
prediction mustEqual testSeq.map(_.b)
}
diff --git a/ml/src/test/scala/frameless/ml/clustering/KMeansTests.scala b/ml/src/test/scala/frameless/ml/clustering/KMeansTests.scala
index a41c1b703..1cb1441b9 100644
--- a/ml/src/test/scala/frameless/ml/clustering/KMeansTests.scala
+++ b/ml/src/test/scala/frameless/ml/clustering/KMeansTests.scala
@@ -3,17 +3,19 @@ package ml
package clustering
import frameless.ml.classification.TypedKMeans
-import frameless.{TypedDataset, TypedEncoder, X1, X2, X3}
+import frameless.{ TypedDataset, TypedEncoder, X1, X2, X3 }
import org.apache.spark.ml.linalg._
-import org.scalacheck.{Arbitrary, Gen}
+import org.scalacheck.{ Arbitrary, Gen }
import org.scalacheck.Prop._
import frameless.ml._
import frameless.ml.params.kmeans.KMeansInitMode
import org.scalatest.matchers.must.Matchers
class KMeansTests extends FramelessMlSuite with Matchers {
+
implicit val arbVector: Arbitrary[Vector] =
Arbitrary(Generators.arbVector.arbitrary)
+
implicit val arbKMeansInitMode: Arbitrary[KMeansInitMode] =
Arbitrary {
Gen.oneOf(
@@ -30,7 +32,7 @@ class KMeansTests extends FramelessMlSuite with Matchers {
val dense = Vectors.dense(dubs)
vect match {
case _: SparseVector => dense.toSparse
- case _ => dense
+ case _ => dense
}
}
@@ -46,17 +48,20 @@ class KMeansTests extends FramelessMlSuite with Matchers {
pDs.select(pDs.col('a)).collect().run().toList == Seq(x1.a, x1a.a)
}
- def prop3[A: TypedEncoder : Arbitrary] = forAll { x2: X2[Vector, A] =>
+ def prop3[A: TypedEncoder: Arbitrary] = forAll { x2: X2[Vector, A] =>
val x2a = x2.copy(a = newRowWithSameDimension(x2.a))
val km = TypedKMeans[X1[Vector]]
val ds = TypedDataset.create(Seq(x2, x2a))
val model = km.fit(ds).run()
val pDs = model.transform(ds).as[X3[Vector, A, Int]]()
- pDs.select(pDs.col('a), pDs.col('b)).collect().run().toList == Seq((x2.a, x2.b), (x2a.a, x2a.b))
+ pDs.select(pDs.col('a), pDs.col('b)).collect().run().toList == Seq(
+ (x2.a, x2.b),
+ (x2a.a, x2a.b)
+ )
}
- tolerantRun( _.isInstanceOf[ArrayIndexOutOfBoundsException] ) {
+ tolerantRun(_.isInstanceOf[ArrayIndexOutOfBoundsException]) {
check(prop)
check(prop3[Double])
}
@@ -76,11 +81,11 @@ class KMeansTests extends FramelessMlSuite with Matchers {
val model = rf.fit(ds).run()
model.transformer.getInitMode == KMeansInitMode.Random.sparkValue &&
- model.transformer.getInitSteps == 2 &&
- model.transformer.getK == 10 &&
- model.transformer.getMaxIter == 15 &&
- model.transformer.getSeed == 123223L &&
- model.transformer.getTol == 12D
+ model.transformer.getInitSteps == 2 &&
+ model.transformer.getK == 10 &&
+ model.transformer.getMaxIter == 15 &&
+ model.transformer.getSeed == 123223L &&
+ model.transformer.getTol == 12D
}
check(prop)
diff --git a/ml/src/test/scala/frameless/ml/feature/TypedIndexToStringTests.scala b/ml/src/test/scala/frameless/ml/feature/TypedIndexToStringTests.scala
index a27f2966d..8dba04d52 100644
--- a/ml/src/test/scala/frameless/ml/feature/TypedIndexToStringTests.scala
+++ b/ml/src/test/scala/frameless/ml/feature/TypedIndexToStringTests.scala
@@ -2,7 +2,7 @@ package frameless
package ml
package feature
-import org.scalacheck.{Arbitrary, Gen}
+import org.scalacheck.{ Arbitrary, Gen }
import org.scalacheck.Prop._
import shapeless.test.illTyped
import org.scalatest.matchers.must.Matchers
diff --git a/ml/src/test/scala/frameless/ml/feature/TypedStringIndexerTests.scala b/ml/src/test/scala/frameless/ml/feature/TypedStringIndexerTests.scala
index 18d490758..db27983a6 100644
--- a/ml/src/test/scala/frameless/ml/feature/TypedStringIndexerTests.scala
+++ b/ml/src/test/scala/frameless/ml/feature/TypedStringIndexerTests.scala
@@ -3,7 +3,7 @@ package ml
package feature
import frameless.ml.feature.TypedStringIndexer.HandleInvalid
-import org.scalacheck.{Arbitrary, Gen}
+import org.scalacheck.{ Arbitrary, Gen }
import org.scalacheck.Prop._
import shapeless.test.illTyped
import org.scalatest.matchers.must.Matchers
@@ -11,7 +11,7 @@ import org.scalatest.matchers.must.Matchers
class TypedStringIndexerTests extends FramelessMlSuite with Matchers {
test(".fit() returns a correct TypedTransformer") {
- def prop[A: TypedEncoder : Arbitrary] = forAll { x2: X2[String, A] =>
+ def prop[A: TypedEncoder: Arbitrary] = forAll { x2: X2[String, A] =>
val indexer = TypedStringIndexer[X1[String]]
val ds = TypedDataset.create(Seq(x2))
val model = indexer.fit(ds).run()
@@ -30,8 +30,8 @@ class TypedStringIndexerTests extends FramelessMlSuite with Matchers {
}
val prop = forAll { handleInvalid: HandleInvalid =>
- val indexer = TypedStringIndexer[X1[String]]
- .setHandleInvalid(handleInvalid)
+ val indexer =
+ TypedStringIndexer[X1[String]].setHandleInvalid(handleInvalid)
val ds = TypedDataset.create(Seq(X1("foo")))
val model = indexer.fit(ds).run()
diff --git a/ml/src/test/scala/frameless/ml/feature/TypedVectorAssemblerTests.scala b/ml/src/test/scala/frameless/ml/feature/TypedVectorAssemblerTests.scala
index 52cbc022b..918f4b487 100644
--- a/ml/src/test/scala/frameless/ml/feature/TypedVectorAssemblerTests.scala
+++ b/ml/src/test/scala/frameless/ml/feature/TypedVectorAssemblerTests.scala
@@ -10,23 +10,54 @@ import shapeless.test.illTyped
class TypedVectorAssemblerTests extends FramelessMlSuite {
test(".transform() returns a correct TypedTransformer") {
- def prop[A: TypedEncoder: Arbitrary] = forAll { x5: X5[Int, Long, Double, Boolean, A] =>
- val assembler = TypedVectorAssembler[X4[Int, Long, Double, Boolean]]
- val ds = TypedDataset.create(Seq(x5))
- val ds2 = assembler.transform(ds).as[X6[Int, Long, Double, Boolean, A, Vector]]()
-
- ds2.collect().run() ==
- Seq(X6(x5.a, x5.b, x5.c, x5.d, x5.e, Vectors.dense(x5.a.toDouble, x5.b.toDouble, x5.c, if (x5.d) 1D else 0D)))
- }
-
- def prop2[A: TypedEncoder: Arbitrary] = forAll { x5: X5[Boolean, BigDecimal, Byte, Short, A] =>
- val assembler = TypedVectorAssembler[X4[Boolean, BigDecimal, Byte, Short]]
- val ds = TypedDataset.create(Seq(x5))
- val ds2 = assembler.transform(ds).as[X6[Boolean, BigDecimal, Byte, Short, A, Vector]]()
-
- ds2.collect().run() ==
- Seq(X6(x5.a, x5.b, x5.c, x5.d, x5.e, Vectors.dense(if (x5.a) 1D else 0D, x5.b.toDouble, x5.c.toDouble, x5.d.toDouble)))
- }
+ def prop[A: TypedEncoder: Arbitrary] =
+ forAll { x5: X5[Int, Long, Double, Boolean, A] =>
+ val assembler = TypedVectorAssembler[X4[Int, Long, Double, Boolean]]
+ val ds = TypedDataset.create(Seq(x5))
+ val ds2 = assembler
+ .transform(ds)
+ .as[X6[Int, Long, Double, Boolean, A, Vector]]()
+
+ ds2.collect().run() ==
+ Seq(
+ X6(
+ x5.a,
+ x5.b,
+ x5.c,
+ x5.d,
+ x5.e,
+ Vectors
+ .dense(x5.a.toDouble, x5.b.toDouble, x5.c, if (x5.d) 1D else 0D)
+ )
+ )
+ }
+
+ def prop2[A: TypedEncoder: Arbitrary] =
+ forAll { x5: X5[Boolean, BigDecimal, Byte, Short, A] =>
+ val assembler =
+ TypedVectorAssembler[X4[Boolean, BigDecimal, Byte, Short]]
+ val ds = TypedDataset.create(Seq(x5))
+ val ds2 = assembler
+ .transform(ds)
+ .as[X6[Boolean, BigDecimal, Byte, Short, A, Vector]]()
+
+ ds2.collect().run() ==
+ Seq(
+ X6(
+ x5.a,
+ x5.b,
+ x5.c,
+ x5.d,
+ x5.e,
+ Vectors.dense(
+ if (x5.a) 1D else 0D,
+ x5.b.toDouble,
+ x5.c.toDouble,
+ x5.d.toDouble
+ )
+ )
+ )
+ }
check(prop[String])
check(prop[Double])
diff --git a/ml/src/test/scala/frameless/ml/regression/RegressionIntegrationTests.scala b/ml/src/test/scala/frameless/ml/regression/RegressionIntegrationTests.scala
index b3db83c74..07ebf7745 100644
--- a/ml/src/test/scala/frameless/ml/regression/RegressionIntegrationTests.scala
+++ b/ml/src/test/scala/frameless/ml/regression/RegressionIntegrationTests.scala
@@ -18,8 +18,13 @@ class RegressionIntegrationTests extends FramelessMlSuite with Matchers {
case class Features(field1: Double, field2: Int)
val vectorAssembler = TypedVectorAssembler[Features]
- case class DataWithFeatures(field1: Double, field2: Int, field3: Double, features: Vector)
- val dataWithFeatures = vectorAssembler.transform(trainingDataDs).as[DataWithFeatures]()
+ case class DataWithFeatures(
+ field1: Double,
+ field2: Int,
+ field3: Double,
+ features: Vector)
+ val dataWithFeatures =
+ vectorAssembler.transform(trainingDataDs).as[DataWithFeatures]()
case class RFInputs(field3: Double, features: Vector)
val rf = TypedRandomForestRegressor[RFInputs]
@@ -28,15 +33,28 @@ class RegressionIntegrationTests extends FramelessMlSuite with Matchers {
// Prediction
- val testData = TypedDataset.create(Seq(
- Data(0D, 10, 0D)
- ))
- val testDataWithFeatures = vectorAssembler.transform(testData).as[DataWithFeatures]()
-
- case class PredictionResult(field1: Double, field2: Int, field3: Double, features: Vector, predictedField3: Double)
- val predictionDs = model.transform(testDataWithFeatures).as[PredictionResult]()
-
- val prediction = predictionDs.select(predictionDs.col('predictedField3)).collect().run().toList
+ val testData = TypedDataset.create(
+ Seq(
+ Data(0D, 10, 0D)
+ )
+ )
+ val testDataWithFeatures =
+ vectorAssembler.transform(testData).as[DataWithFeatures]()
+
+ case class PredictionResult(
+ field1: Double,
+ field2: Int,
+ field3: Double,
+ features: Vector,
+ predictedField3: Double)
+ val predictionDs =
+ model.transform(testDataWithFeatures).as[PredictionResult]()
+
+ val prediction = predictionDs
+ .select(predictionDs.col('predictedField3))
+ .collect()
+ .run()
+ .toList
prediction mustEqual List(0D)
}
diff --git a/ml/src/test/scala/frameless/ml/regression/TypedLinearRegressionTests.scala b/ml/src/test/scala/frameless/ml/regression/TypedLinearRegressionTests.scala
index b864b1533..5fbccb087 100644
--- a/ml/src/test/scala/frameless/ml/regression/TypedLinearRegressionTests.scala
+++ b/ml/src/test/scala/frameless/ml/regression/TypedLinearRegressionTests.scala
@@ -2,7 +2,7 @@ package frameless
package ml
package regression
-import frameless.ml.params.linears.{LossStrategy, Solver}
+import frameless.ml.params.linears.{ LossStrategy, Solver }
import org.apache.spark.ml.linalg._
import org.scalacheck.Arbitrary
import org.scalacheck.Prop._
@@ -11,7 +11,9 @@ import shapeless.test.illTyped
class TypedLinearRegressionTests extends FramelessMlSuite with Matchers {
- implicit val arbVectorNonEmpty: Arbitrary[Vector] = Arbitrary(Generators.arbVector.arbitrary)
+ implicit val arbVectorNonEmpty: Arbitrary[Vector] = Arbitrary(
+ Generators.arbVector.arbitrary
+ )
test("fit() returns a correct TypedTransformer") {
val prop = forAll { x2: X2[Double, Vector] =>
@@ -32,14 +34,18 @@ class TypedLinearRegressionTests extends FramelessMlSuite with Matchers {
pDs.select(pDs.col('a), pDs.col('b)).collect().run() == Seq(x2.a -> x2.b)
}
- def prop3[A: TypedEncoder: Arbitrary] = forAll { x3: X3[Vector, Double, A] =>
- val lr = TypedLinearRegression[X2[Vector, Double]]
- val ds = TypedDataset.create(Seq(x3))
- val model = lr.fit(ds).run()
- val pDs = model.transform(ds).as[X4[Vector, Double, A, Double]]()
+ def prop3[A: TypedEncoder: Arbitrary] =
+ forAll { x3: X3[Vector, Double, A] =>
+ val lr = TypedLinearRegression[X2[Vector, Double]]
+ val ds = TypedDataset.create(Seq(x3))
+ val model = lr.fit(ds).run()
+ val pDs = model.transform(ds).as[X4[Vector, Double, A, Double]]()
- pDs.select(pDs.col('a), pDs.col('b), pDs.col('c)).collect().run() == Seq((x3.a, x3.b, x3.c))
- }
+ pDs
+ .select(pDs.col('a), pDs.col('b), pDs.col('c))
+ .collect()
+ .run() == Seq((x3.a, x3.b, x3.c))
+ }
check(prop)
check(prop2)
@@ -48,7 +54,7 @@ class TypedLinearRegressionTests extends FramelessMlSuite with Matchers {
}
test("param setting is retained") {
- import Generators.{arbLossStrategy, arbSolver}
+ import Generators.{ arbLossStrategy, arbSolver }
val prop = forAll { (lossStrategy: LossStrategy, solver: Solver) =>
val lr = TypedLinearRegression[X2[Double, Vector]]
@@ -66,12 +72,12 @@ class TypedLinearRegressionTests extends FramelessMlSuite with Matchers {
val model = lr.fit(ds).run()
model.transformer.getAggregationDepth == 10 &&
- model.transformer.getEpsilon == 4.0 &&
- model.transformer.getLoss == lossStrategy.sparkValue &&
- model.transformer.getMaxIter == 23 &&
- model.transformer.getRegParam == 1.2 &&
- model.transformer.getTol == 2.3 &&
- model.transformer.getSolver == solver.sparkValue
+ model.transformer.getEpsilon == 4.0 &&
+ model.transformer.getLoss == lossStrategy.sparkValue &&
+ model.transformer.getMaxIter == 23 &&
+ model.transformer.getRegParam == 1.2 &&
+ model.transformer.getTol == 2.3 &&
+ model.transformer.getSolver == solver.sparkValue
}
check(prop)
@@ -98,25 +104,23 @@ class TypedLinearRegressionTests extends FramelessMlSuite with Matchers {
)
val ds2 = Seq(
- X3(new DenseVector(Array(1.0)): Vector,2F, 1.0),
- X3(new DenseVector(Array(2.0)): Vector,2F, 2.0),
- X3(new DenseVector(Array(3.0)): Vector,2F, 3.0),
- X3(new DenseVector(Array(4.0)): Vector,2F, 4.0),
- X3(new DenseVector(Array(5.0)): Vector,2F, 5.0),
- X3(new DenseVector(Array(6.0)): Vector,2F, 6.0)
+ X3(new DenseVector(Array(1.0)): Vector, 2F, 1.0),
+ X3(new DenseVector(Array(2.0)): Vector, 2F, 2.0),
+ X3(new DenseVector(Array(3.0)): Vector, 2F, 3.0),
+ X3(new DenseVector(Array(4.0)): Vector, 2F, 4.0),
+ X3(new DenseVector(Array(5.0)): Vector, 2F, 5.0),
+ X3(new DenseVector(Array(6.0)): Vector, 2F, 6.0)
)
val tds = TypedDataset.create(ds)
- val lr = TypedLinearRegression[X2[Vector, Double]]
- .setMaxIter(10)
+ val lr = TypedLinearRegression[X2[Vector, Double]].setMaxIter(10)
val model = lr.fit(tds).run()
val tds2 = TypedDataset.create(ds2)
- val lr2 = TypedLinearRegression[X3[Vector, Float, Double]]
- .setMaxIter(10)
+ val lr2 = TypedLinearRegression[X3[Vector, Float, Double]].setMaxIter(10)
val model2 = lr2.fit(tds2).run()
diff --git a/ml/src/test/scala/frameless/ml/regression/TypedRandomForestRegressorTests.scala b/ml/src/test/scala/frameless/ml/regression/TypedRandomForestRegressorTests.scala
index 4a6cd37d2..cc37a6b6e 100644
--- a/ml/src/test/scala/frameless/ml/regression/TypedRandomForestRegressorTests.scala
+++ b/ml/src/test/scala/frameless/ml/regression/TypedRandomForestRegressorTests.scala
@@ -10,8 +10,11 @@ import org.scalacheck.Prop._
import org.scalatest.matchers.must.Matchers
class TypedRandomForestRegressorTests extends FramelessMlSuite with Matchers {
+
implicit val arbVectorNonEmpty: Arbitrary[Vector] =
- Arbitrary(Generators.arbVector.arbitrary suchThat (_.size > 0)) // vector must not be empty for RandomForestRegressor
+ Arbitrary(
+ Generators.arbVector.arbitrary suchThat (_.size > 0)
+ ) // vector must not be empty for RandomForestRegressor
import Generators.arbTreesFeaturesSubsetStrategy
test("fit() returns a correct TypedTransformer") {
@@ -33,14 +36,18 @@ class TypedRandomForestRegressorTests extends FramelessMlSuite with Matchers {
pDs.select(pDs.col('a), pDs.col('b)).collect().run() == Seq(x2.a -> x2.b)
}
- def prop3[A: TypedEncoder: Arbitrary] = forAll { x3: X3[Vector, Double, A] =>
- val rf = TypedRandomForestRegressor[X2[Vector, Double]]
- val ds = TypedDataset.create(Seq(x3))
- val model = rf.fit(ds).run()
- val pDs = model.transform(ds).as[X4[Vector, Double, A, Double]]()
+ def prop3[A: TypedEncoder: Arbitrary] =
+ forAll { x3: X3[Vector, Double, A] =>
+ val rf = TypedRandomForestRegressor[X2[Vector, Double]]
+ val ds = TypedDataset.create(Seq(x3))
+ val model = rf.fit(ds).run()
+ val pDs = model.transform(ds).as[X4[Vector, Double, A, Double]]()
- pDs.select(pDs.col('a), pDs.col('b), pDs.col('c)).collect().run() == Seq((x3.a, x3.b, x3.c))
- }
+ pDs
+ .select(pDs.col('a), pDs.col('b), pDs.col('c))
+ .collect()
+ .run() == Seq((x3.a, x3.b, x3.c))
+ }
check(prop)
check(prop2)
@@ -64,13 +71,13 @@ class TypedRandomForestRegressorTests extends FramelessMlSuite with Matchers {
val model = rf.fit(ds).run()
model.transformer.getNumTrees == 10 &&
- model.transformer.getMaxBins == 100 &&
- model.transformer.getFeatureSubsetStrategy == featureSubsetStrategy.sparkValue &&
- model.transformer.getMaxDepth == 10 &&
- model.transformer.getMaxMemoryInMB == 100 &&
- model.transformer.getMinInfoGain == 0.1D &&
- model.transformer.getMinInstancesPerNode == 2 &&
- model.transformer.getSubsamplingRate == 0.9D
+ model.transformer.getMaxBins == 100 &&
+ model.transformer.getFeatureSubsetStrategy == featureSubsetStrategy.sparkValue &&
+ model.transformer.getMaxDepth == 10 &&
+ model.transformer.getMaxMemoryInMB == 100 &&
+ model.transformer.getMinInfoGain == 0.1D &&
+ model.transformer.getMinInstancesPerNode == 2 &&
+ model.transformer.getSubsamplingRate == 0.9D
}
check(prop)
diff --git a/refined/src/main/scala/frameless/refined/RefinedFieldEncoders.scala b/refined/src/main/scala/frameless/refined/RefinedFieldEncoders.scala
index dba59454c..2dacb45d9 100644
--- a/refined/src/main/scala/frameless/refined/RefinedFieldEncoders.scala
+++ b/refined/src/main/scala/frameless/refined/RefinedFieldEncoders.scala
@@ -4,7 +4,10 @@ import scala.reflect.ClassTag
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.{
- Invoke, NewInstance, UnwrapOption, WrapOption
+ Invoke,
+ NewInstance,
+ UnwrapOption,
+ WrapOption
}
import org.apache.spark.sql.types._
@@ -13,15 +16,16 @@ import eu.timepit.refined.api.RefType
import frameless.{ TypedEncoder, RecordFieldEncoder }
private[refined] trait RefinedFieldEncoders {
+
/**
* @tparam T the refined type (e.g. `String`)
*/
implicit def optionRefined[F[_, _], T, R](
- implicit
+ implicit
i0: RefType[F],
i1: TypedEncoder[T],
- i2: ClassTag[F[T, R]],
- ): RecordFieldEncoder[Option[F[T, R]]] =
+ i2: ClassTag[F[T, R]]
+ ): RecordFieldEncoder[Option[F[T, R]]] =
RecordFieldEncoder[Option[F[T, R]]](new TypedEncoder[Option[F[T, R]]] {
def nullable = true
@@ -54,11 +58,11 @@ private[refined] trait RefinedFieldEncoders {
* @tparam T the refined type (e.g. `String`)
*/
implicit def refined[F[_, _], T, R](
- implicit
+ implicit
i0: RefType[F],
i1: TypedEncoder[T],
- i2: ClassTag[F[T, R]],
- ): RecordFieldEncoder[F[T, R]] =
+ i2: ClassTag[F[T, R]]
+ ): RecordFieldEncoder[F[T, R]] =
RecordFieldEncoder[F[T, R]](new TypedEncoder[F[T, R]] {
def nullable = i1.nullable
@@ -76,4 +80,3 @@ private[refined] trait RefinedFieldEncoders {
override def toString = s"refined[${i2.runtimeClass.getName}]"
})
}
-
diff --git a/refined/src/main/scala/frameless/refined/package.scala b/refined/src/main/scala/frameless/refined/package.scala
index 8819be2bf..2786e14e3 100644
--- a/refined/src/main/scala/frameless/refined/package.scala
+++ b/refined/src/main/scala/frameless/refined/package.scala
@@ -5,8 +5,9 @@ import scala.reflect.ClassTag
import eu.timepit.refined.api.{ RefType, Validate }
package object refined extends RefinedFieldEncoders {
+
implicit def refinedInjection[F[_, _], T, R](
- implicit
+ implicit
refType: RefType[F],
validate: Validate[T, R]
): Injection[F[T, R], T] = Injection(
@@ -15,19 +16,20 @@ package object refined extends RefinedFieldEncoders {
refType.refine[R](value) match {
case Left(errMsg) =>
throw new IllegalArgumentException(
- s"Value $value does not satisfy refinement predicate: $errMsg")
+ s"Value $value does not satisfy refinement predicate: $errMsg"
+ )
case Right(res) => res
}
- })
+ }
+ )
implicit def refinedEncoder[F[_, _], T, R](
- implicit
+ implicit
i0: RefType[F],
i1: Validate[T, R],
i2: TypedEncoder[T],
i3: ClassTag[F[T, R]]
- ): TypedEncoder[F[T, R]] = TypedEncoder.usingInjection(
- i3, refinedInjection, i2)
+ ): TypedEncoder[F[T, R]] =
+ TypedEncoder.usingInjection(i3, refinedInjection, i2)
}
-
diff --git a/refined/src/test/scala/frameless/RefinedFieldEncoderTests.scala b/refined/src/test/scala/frameless/RefinedFieldEncoderTests.scala
index 5476284ea..c152abf34 100644
--- a/refined/src/test/scala/frameless/RefinedFieldEncoderTests.scala
+++ b/refined/src/test/scala/frameless/RefinedFieldEncoderTests.scala
@@ -2,7 +2,11 @@ package frameless
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{
- IntegerType, ObjectType, StringType, StructField, StructType
+ IntegerType,
+ ObjectType,
+ StringType,
+ StructField,
+ StructType
}
import org.scalatest.matchers.should.Matchers
@@ -24,7 +28,8 @@ class RefinedFieldEncoderTests extends TypedDatasetSuite with Matchers {
val nes: NonEmptyString = "Non Empty String"
- val unsafeDs = TypedDataset.createUnsafe(sc.parallelize(Seq(nes.value)).toDF())(encoder)
+ val unsafeDs =
+ TypedDataset.createUnsafe(sc.parallelize(Seq(nes.value)).toDF())(encoder)
val expected = Seq(nes)
@@ -40,9 +45,12 @@ class RefinedFieldEncoderTests extends TypedDatasetSuite with Matchers {
encoderA.jvmRepr shouldBe ObjectType(classOf[A])
// Check catalystRepr
- val expectedAStructType = StructType(Seq(
- StructField("a", IntegerType, false),
- StructField("s", StringType, false)))
+ val expectedAStructType = StructType(
+ Seq(
+ StructField("a", IntegerType, false),
+ StructField("s", StringType, false)
+ )
+ )
encoderA.catalystRepr shouldBe expectedAStructType
@@ -71,18 +79,23 @@ class RefinedFieldEncoderTests extends TypedDatasetSuite with Matchers {
encoderB.jvmRepr shouldBe ObjectType(classOf[B])
// Check catalystRepr
- val expectedBStructType = StructType(Seq(
- StructField("a", IntegerType, false),
- StructField("s", StringType, true)))
+ val expectedBStructType = StructType(
+ Seq(
+ StructField("a", IntegerType, false),
+ StructField("s", StringType, true)
+ )
+ )
encoderB.catalystRepr shouldBe expectedBStructType
// Check unsafe
val unsafeDs: TypedDataset[B] = {
- val rdd = sc.parallelize(Seq(
- Row(bs.a, bs.s.mkString),
- Row(2, null.asInstanceOf[String]),
- ))
+ val rdd = sc.parallelize(
+ Seq(
+ Row(bs.a, bs.s.mkString),
+ Row(2, null.asInstanceOf[String])
+ )
+ )
val df = session.createDataFrame(rdd, expectedBStructType)