diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 000000000..20c878bbb --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,2 @@ +# Scala Steward: Reformat with scalafmt 3.8.6 +67ab2b447c399dc16378de4d527930cc5b7f1c7a diff --git a/.scalafmt.conf b/.scalafmt.conf index 038f3d925..771bfd31a 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,4 +1,4 @@ -version = 3.8.1 +version = 3.8.6 runner.dialect = scala213 newlines.beforeMultilineDef = keep diff --git a/build.sbt b/build.sbt index 4f8064cb3..77370f693 100644 --- a/build.sbt +++ b/build.sbt @@ -244,7 +244,10 @@ lazy val datasetSettings = ) }, coverageExcludedPackages := "org.apache.spark.sql.reflection", - libraryDependencies += "com.globalmentor" % "hadoop-bare-naked-local-fs" % nakedFSVersion % Test exclude ("org.apache.hadoop", "hadoop-commons") + libraryDependencies += "com.globalmentor" % "hadoop-bare-naked-local-fs" % nakedFSVersion % Test exclude ( + "org.apache.hadoop", + "hadoop-commons" + ) ) lazy val refinedSettings = diff --git a/cats/src/main/scala/frameless/cats/FramelessSyntax.scala b/cats/src/main/scala/frameless/cats/FramelessSyntax.scala index 663ae5958..5bc5d63e7 100644 --- a/cats/src/main/scala/frameless/cats/FramelessSyntax.scala +++ b/cats/src/main/scala/frameless/cats/FramelessSyntax.scala @@ -7,18 +7,25 @@ import _root_.cats.mtl.Ask import org.apache.spark.sql.SparkSession trait FramelessSyntax extends frameless.FramelessSyntax { - implicit class SparkJobOps[F[_], A](fa: F[A])(implicit S: Sync[F], A: Ask[F, SparkSession]) { + + implicit class SparkJobOps[F[_], A]( + fa: F[A] + )(implicit + S: Sync[F], + A: Ask[F, SparkSession]) { import S._, A._ def withLocalProperty(key: String, value: String): F[A] = for { session <- ask - _ <- delay(session.sparkContext.setLocalProperty(key, value)) - a <- fa + _ <- delay(session.sparkContext.setLocalProperty(key, value)) + a <- fa } yield a - def withGroupId(groupId: String): F[A] = withLocalProperty("spark.jobGroup.id", groupId) + def withGroupId(groupId: String): F[A] = + withLocalProperty("spark.jobGroup.id", groupId) - def withDescription(description: String): F[A] = withLocalProperty("spark.job.description", description) + def withDescription(description: String): F[A] = + withLocalProperty("spark.job.description", description) } } diff --git a/cats/src/main/scala/frameless/cats/SparkDelayInstances.scala b/cats/src/main/scala/frameless/cats/SparkDelayInstances.scala index 524c44117..ecd893925 100644 --- a/cats/src/main/scala/frameless/cats/SparkDelayInstances.scala +++ b/cats/src/main/scala/frameless/cats/SparkDelayInstances.scala @@ -5,7 +5,15 @@ import _root_.cats.effect.Sync import org.apache.spark.sql.SparkSession trait SparkDelayInstances { - implicit def framelessCatsSparkDelayForSync[F[_]](implicit S: Sync[F]): SparkDelay[F] = new SparkDelay[F] { - def delay[A](a: => A)(implicit spark: SparkSession): F[A] = S.delay(a) + + implicit def framelessCatsSparkDelayForSync[F[_]]( + implicit + S: Sync[F] + ): SparkDelay[F] = new SparkDelay[F] { + def delay[A]( + a: => A + )(implicit + spark: SparkSession + ): F[A] = S.delay(a) } } diff --git a/cats/src/main/scala/frameless/cats/SparkTask.scala b/cats/src/main/scala/frameless/cats/SparkTask.scala index 3a6e6330b..166835a6a 100644 --- a/cats/src/main/scala/frameless/cats/SparkTask.scala +++ b/cats/src/main/scala/frameless/cats/SparkTask.scala @@ -6,6 +6,7 @@ import _root_.cats.data.Kleisli import org.apache.spark.SparkContext object SparkTask { + def apply[A](f: SparkContext => A): 
SparkTask[A] = Kleisli[Id, SparkContext, A](f) diff --git a/cats/src/main/scala/frameless/cats/implicits.scala b/cats/src/main/scala/frameless/cats/implicits.scala index 1fa869a7f..fdfadf80d 100644 --- a/cats/src/main/scala/frameless/cats/implicits.scala +++ b/cats/src/main/scala/frameless/cats/implicits.scala @@ -2,7 +2,7 @@ package frameless package cats import _root_.cats._ -import _root_.cats.kernel.{CommutativeMonoid, CommutativeSemigroup} +import _root_.cats.kernel.{ CommutativeMonoid, CommutativeSemigroup } import _root_.cats.syntax.all._ import alleycats.Empty @@ -10,42 +10,78 @@ import scala.reflect.ClassTag import org.apache.spark.rdd.RDD object implicits extends FramelessSyntax with SparkDelayInstances { + implicit class rddOps[A: ClassTag](lhs: RDD[A]) { - def csum(implicit m: CommutativeMonoid[A]): A = + + def csum( + implicit + m: CommutativeMonoid[A] + ): A = lhs.fold(m.empty)(_ |+| _) - def csumOption(implicit m: CommutativeSemigroup[A]): Option[A] = + + def csumOption( + implicit + m: CommutativeSemigroup[A] + ): Option[A] = lhs.aggregate[Option[A]](None)( (acc, a) => Some(acc.fold(a)(_ |+| a)), (l, r) => l.fold(r)(x => r.map(_ |+| x) orElse Some(x)) ) - def cmin(implicit o: Order[A], e: Empty[A]): A = { + def cmin(implicit + o: Order[A], + e: Empty[A] + ): A = { if (lhs.isEmpty()) e.empty else lhs.reduce(_ min _) } - def cminOption(implicit o: Order[A]): Option[A] = + + def cminOption( + implicit + o: Order[A] + ): Option[A] = csumOption(new CommutativeSemigroup[A] { def combine(l: A, r: A) = l min r }) - def cmax(implicit o: Order[A], e: Empty[A]): A = { + def cmax(implicit + o: Order[A], + e: Empty[A] + ): A = { if (lhs.isEmpty()) e.empty else lhs.reduce(_ max _) } - def cmaxOption(implicit o: Order[A]): Option[A] = + + def cmaxOption( + implicit + o: Order[A] + ): Option[A] = csumOption(new CommutativeSemigroup[A] { def combine(l: A, r: A) = l max r }) } implicit class pairRddOps[K: ClassTag, V: ClassTag](lhs: RDD[(K, V)]) { - def csumByKey(implicit m: CommutativeSemigroup[V]): RDD[(K, V)] = lhs.reduceByKey(_ |+| _) - def cminByKey(implicit o: Order[V]): RDD[(K, V)] = lhs.reduceByKey(_ min _) - def cmaxByKey(implicit o: Order[V]): RDD[(K, V)] = lhs.reduceByKey(_ max _) + + def csumByKey( + implicit + m: CommutativeSemigroup[V] + ): RDD[(K, V)] = lhs.reduceByKey(_ |+| _) + + def cminByKey( + implicit + o: Order[V] + ): RDD[(K, V)] = lhs.reduceByKey(_ min _) + + def cmaxByKey( + implicit + o: Order[V] + ): RDD[(K, V)] = lhs.reduceByKey(_ max _) } } object union { + implicit def unionSemigroup[A]: Semigroup[RDD[A]] = new Semigroup[RDD[A]] { def combine(lhs: RDD[A], rhs: RDD[A]): RDD[A] = lhs union rhs @@ -53,7 +89,11 @@ object union { } object inner { - implicit def pairwiseInnerSemigroup[K: ClassTag, V: ClassTag: Semigroup]: Semigroup[RDD[(K, V)]] = + + implicit def pairwiseInnerSemigroup[ + K: ClassTag, + V: ClassTag: Semigroup + ]: Semigroup[RDD[(K, V)]] = new Semigroup[RDD[(K, V)]] { def combine(lhs: RDD[(K, V)], rhs: RDD[(K, V)]): RDD[(K, V)] = lhs.join(rhs).mapValues { case (x, y) => x |+| y } @@ -61,14 +101,18 @@ object inner { } object outer { - implicit def pairwiseOuterSemigroup[K: ClassTag, V: ClassTag](implicit m: Monoid[V]): Semigroup[RDD[(K, V)]] = + + implicit def pairwiseOuterSemigroup[K: ClassTag, V: ClassTag]( + implicit + m: Monoid[V] + ): Semigroup[RDD[(K, V)]] = new Semigroup[RDD[(K, V)]] { def combine(lhs: RDD[(K, V)], rhs: RDD[(K, V)]): RDD[(K, V)] = lhs.fullOuterJoin(rhs).mapValues { case (Some(x), Some(y)) => x |+| y - case (None, Some(y)) => y - 
case (Some(x), None) => x - case (None, None) => m.empty + case (None, Some(y)) => y + case (Some(x), None) => x + case (None, None) => m.empty } } } diff --git a/cats/src/test/scala/frameless/cats/FramelessSyntaxTests.scala b/cats/src/test/scala/frameless/cats/FramelessSyntaxTests.scala index 95ffbed26..a5246ad37 100644 --- a/cats/src/test/scala/frameless/cats/FramelessSyntaxTests.scala +++ b/cats/src/test/scala/frameless/cats/FramelessSyntaxTests.scala @@ -6,16 +6,18 @@ import _root_.cats.effect.IO import _root_.cats.effect.unsafe.implicits.global import org.apache.spark.sql.SparkSession import org.scalatest.matchers.should.Matchers -import org.scalacheck.{Test => PTest} +import org.scalacheck.{ Test => PTest } import org.scalacheck.Prop, Prop._ import org.scalacheck.effect.PropF, PropF._ class FramelessSyntaxTests extends TypedDatasetSuite with Matchers { override val sparkDelay = null - def prop[A, B](data: Vector[X2[A, B]])( - implicit ev: TypedEncoder[X2[A, B]] - ): Prop = { + def prop[A, B]( + data: Vector[X2[A, B]] + )(implicit + ev: TypedEncoder[X2[A, B]] + ): Prop = { import implicits._ val dataset = TypedDataset.create(data).dataset @@ -24,7 +26,13 @@ class FramelessSyntaxTests extends TypedDatasetSuite with Matchers { val typedDataset = dataset.typed val typedDatasetFromDataFrame = dataframe.unsafeTyped[X2[A, B]] - typedDataset.collect[IO]().unsafeRunSync().toVector ?= typedDatasetFromDataFrame.collect[IO]().unsafeRunSync().toVector + typedDataset + .collect[IO]() + .unsafeRunSync() + .toVector ?= typedDatasetFromDataFrame + .collect[IO]() + .unsafeRunSync() + .toVector } test("dataset typed - toTyped") { @@ -37,8 +45,7 @@ class FramelessSyntaxTests extends TypedDatasetSuite with Matchers { forAllF { (k: String, v: String) => val scopedKey = "frameless.tests." 
+ k - 1 - .pure[ReaderT[IO, SparkSession, *]] + 1.pure[ReaderT[IO, SparkSession, *]] .withLocalProperty(scopedKey, v) .withGroupId(v) .withDescription(v) @@ -47,7 +54,8 @@ class FramelessSyntaxTests extends TypedDatasetSuite with Matchers { sc.getLocalProperty(scopedKey) shouldBe v sc.getLocalProperty("spark.jobGroup.id") shouldBe v sc.getLocalProperty("spark.job.description") shouldBe v - }.void + } + .void }.check().unsafeRunSync().status shouldBe PTest.Passed } } diff --git a/cats/src/test/scala/frameless/cats/test.scala b/cats/src/test/scala/frameless/cats/test.scala index d75bc3bfd..d9314238f 100644 --- a/cats/src/test/scala/frameless/cats/test.scala +++ b/cats/src/test/scala/frameless/cats/test.scala @@ -7,7 +7,7 @@ import _root_.cats.syntax.all._ import org.apache.spark.SparkContext import org.apache.spark.sql.SparkSession import org.apache.spark.rdd.RDD -import org.apache.spark.{SparkConf, SparkContext => SC} +import org.apache.spark.{ SparkConf, SparkContext => SC } import org.scalatest.compatible.Assertion import org.scalactic.anyvals.PosInt @@ -21,7 +21,11 @@ import org.scalatest.matchers.should.Matchers import org.scalatest.propspec.AnyPropSpec trait SparkTests { - val appID: String = new java.util.Date().toString + math.floor(math.random() * 10E4).toLong.toString + + val appID: String = new java.util.Date().toString + math + .floor(math.random() * 10e4) + .toLong + .toString val conf: SparkConf = new SparkConf() .setMaster("local[*]") @@ -29,16 +33,27 @@ trait SparkTests { .set("spark.ui.enabled", "false") .set("spark.app.id", appID) - implicit def session: SparkSession = SparkSession.builder().config(conf).getOrCreate() + implicit def session: SparkSession = + SparkSession.builder().config(conf).getOrCreate() implicit def sc: SparkContext = session.sparkContext - implicit class seqToRdd[A: ClassTag](seq: Seq[A])(implicit sc: SC) { + implicit class seqToRdd[A: ClassTag]( + seq: Seq[A] + )(implicit + sc: SC) { def toRdd: RDD[A] = sc.makeRDD(seq) } } object Tests { - def innerPairwise(mx: Map[String, Int], my: Map[String, Int], check: (Any, Any) => Assertion)(implicit sc: SC): Assertion = { + + def innerPairwise( + mx: Map[String, Int], + my: Map[String, Int], + check: (Any, Any) => Assertion + )(implicit + sc: SC + ): Assertion = { import frameless.cats.implicits._ import frameless.cats.inner._ val xs = sc.parallelize(mx.toSeq) @@ -63,18 +78,27 @@ object Tests { } } -class Test extends AnyPropSpec with Matchers with ScalaCheckPropertyChecks with SparkTests { +class Test + extends AnyPropSpec + with Matchers + with ScalaCheckPropertyChecks + with SparkTests { + implicit override val generatorDrivenConfig = PropertyCheckConfiguration(minSize = PosInt(10)) property("spark is working") { - sc.parallelize(Seq(1, 2, 3)).collect() shouldBe Array(1,2,3) + sc.parallelize(Seq(1, 2, 3)).collect() shouldBe Array(1, 2, 3) } property("inner pairwise monoid") { // Make sure we have non-empty map - forAll { (xh: (String, Int), mx: Map[String, Int], yh: (String, Int), my: Map[String, Int]) => - Tests.innerPairwise(mx + xh, my + yh, _ shouldBe _) + forAll { + (xh: (String, Int), + mx: Map[String, Int], + yh: (String, Int), + my: Map[String, Int] + ) => Tests.innerPairwise(mx + xh, my + yh, _ shouldBe _) } } @@ -110,7 +134,8 @@ class Test extends AnyPropSpec with Matchers with ScalaCheckPropertyChecks with property("rdd tuple commutative semigroup example") { import frameless.cats.implicits._ forAll { seq: List[(Int, Int)] => - val expectedSum = if (seq.isEmpty) None else 
Some(Foldable[List].fold(seq)) + val expectedSum = + if (seq.isEmpty) None else Some(Foldable[List].fold(seq)) val rdd = seq.toRdd rdd.csum shouldBe expectedSum.getOrElse(0 -> 0) @@ -120,10 +145,22 @@ class Test extends AnyPropSpec with Matchers with ScalaCheckPropertyChecks with property("pair rdd numeric commutative semigroup example") { import frameless.cats.implicits._ - val seq = Seq( ("a",2), ("b",3), ("d",6), ("b",2), ("d",1) ) + val seq = Seq(("a", 2), ("b", 3), ("d", 6), ("b", 2), ("d", 1)) val rdd = seq.toRdd - rdd.cminByKey.collect().toSeq should contain theSameElementsAs Seq( ("a",2), ("b",2), ("d",1) ) - rdd.cmaxByKey.collect().toSeq should contain theSameElementsAs Seq( ("a",2), ("b",3), ("d",6) ) - rdd.csumByKey.collect().toSeq should contain theSameElementsAs Seq( ("a",2), ("b",5), ("d",7) ) + rdd.cminByKey.collect().toSeq should contain theSameElementsAs Seq( + ("a", 2), + ("b", 2), + ("d", 1) + ) + rdd.cmaxByKey.collect().toSeq should contain theSameElementsAs Seq( + ("a", 2), + ("b", 3), + ("d", 6) + ) + rdd.csumByKey.collect().toSeq should contain theSameElementsAs Seq( + ("a", 2), + ("b", 5), + ("d", 7) + ) } } diff --git a/core/src/main/scala/frameless/CatalystAverageable.scala b/core/src/main/scala/frameless/CatalystAverageable.scala index 401ed65fc..60a9ee347 100644 --- a/core/src/main/scala/frameless/CatalystAverageable.scala +++ b/core/src/main/scala/frameless/CatalystAverageable.scala @@ -3,24 +3,35 @@ package frameless import scala.annotation.implicitNotFound /** - * When averaging Spark doesn't change these types: - * - BigDecimal -> BigDecimal - * - Double -> Double - * But it changes these types : - * - Int -> Double - * - Short -> Double - * - Long -> Double - */ + * When averaging Spark doesn't change these types: + * - BigDecimal -> BigDecimal + * - Double -> Double + * But it changes these types : + * - Int -> Double + * - Short -> Double + * - Long -> Double + */ @implicitNotFound("Cannot compute average of type ${In}.") trait CatalystAverageable[In, Out] object CatalystAverageable { private[this] val theInstance = new CatalystAverageable[Any, Any] {} - private[this] def of[In, Out]: CatalystAverageable[In, Out] = theInstance.asInstanceOf[CatalystAverageable[In, Out]] - implicit val framelessAverageableBigDecimal: CatalystAverageable[BigDecimal, BigDecimal] = of[BigDecimal, BigDecimal] - implicit val framelessAverageableDouble: CatalystAverageable[Double, Double] = of[Double, Double] - implicit val framelessAverageableLong: CatalystAverageable[Long, Double] = of[Long, Double] - implicit val framelessAverageableInt: CatalystAverageable[Int, Double] = of[Int, Double] - implicit val framelessAverageableShort: CatalystAverageable[Short, Double] = of[Short, Double] + private[this] def of[In, Out]: CatalystAverageable[In, Out] = + theInstance.asInstanceOf[CatalystAverageable[In, Out]] + + implicit val framelessAverageableBigDecimal: CatalystAverageable[BigDecimal, BigDecimal] = + of[BigDecimal, BigDecimal] + + implicit val framelessAverageableDouble: CatalystAverageable[Double, Double] = + of[Double, Double] + + implicit val framelessAverageableLong: CatalystAverageable[Long, Double] = + of[Long, Double] + + implicit val framelessAverageableInt: CatalystAverageable[Int, Double] = + of[Int, Double] + + implicit val framelessAverageableShort: CatalystAverageable[Short, Double] = + of[Short, Double] } diff --git a/core/src/main/scala/frameless/CatalystBitShift.scala b/core/src/main/scala/frameless/CatalystBitShift.scala index 753a61907..3e1cdbefa 100644 --- 
a/core/src/main/scala/frameless/CatalystBitShift.scala +++ b/core/src/main/scala/frameless/CatalystBitShift.scala @@ -2,19 +2,29 @@ package frameless import scala.annotation.implicitNotFound -/** Spark does not return always Int on shift - */ +/** + * Spark does not always return Int on shift + */ @implicitNotFound("Cannot do bit shift operations on columns of type ${In}.") trait CatalystBitShift[In, Out] object CatalystBitShift { private[this] val theInstance = new CatalystBitShift[Any, Any] {} - private[this] def of[In, Out]: CatalystBitShift[In, Out] = theInstance.asInstanceOf[CatalystBitShift[In, Out]] - implicit val framelessBitShiftBigDecimal: CatalystBitShift[BigDecimal, Int] = of[BigDecimal, Int] - implicit val framelessBitShiftDouble : CatalystBitShift[Byte, Int] = of[Byte, Int] - implicit val framelessBitShiftInt : CatalystBitShift[Short, Int] = of[Short, Int] - implicit val framelessBitShiftLong : CatalystBitShift[Int, Int] = of[Int, Int] - implicit val framelessBitShiftShort : CatalystBitShift[Long, Long] = of[Long, Long] + private[this] def of[In, Out]: CatalystBitShift[In, Out] = + theInstance.asInstanceOf[CatalystBitShift[In, Out]] + + implicit val framelessBitShiftBigDecimal: CatalystBitShift[BigDecimal, Int] = + of[BigDecimal, Int] + + implicit val framelessBitShiftDouble: CatalystBitShift[Byte, Int] = + of[Byte, Int] + + implicit val framelessBitShiftInt: CatalystBitShift[Short, Int] = + of[Short, Int] + implicit val framelessBitShiftLong: CatalystBitShift[Int, Int] = of[Int, Int] + + implicit val framelessBitShiftShort: CatalystBitShift[Long, Long] = + of[Long, Long] } diff --git a/core/src/main/scala/frameless/CatalystCast.scala b/core/src/main/scala/frameless/CatalystCast.scala index 1a8a21573..94fa59a69 100644 --- a/core/src/main/scala/frameless/CatalystCast.scala +++ b/core/src/main/scala/frameless/CatalystCast.scala @@ -4,29 +4,62 @@ trait CatalystCast[A, B] object CatalystCast { private[this] val theInstance = new CatalystCast[Any, Any] {} - private[this] def of[A, B]: CatalystCast[A, B] = theInstance.asInstanceOf[CatalystCast[A, B]] + + private[this] def of[A, B]: CatalystCast[A, B] = + theInstance.asInstanceOf[CatalystCast[A, B]] implicit def framelessCastToString[T]: CatalystCast[T, String] = of[T, String] - implicit def framelessNumericToLong [A: CatalystNumeric]: CatalystCast[A, Long] = of[A, Long] - implicit def framelessNumericToInt [A: CatalystNumeric]: CatalystCast[A, Int] = of[A, Int] - implicit def framelessNumericToShort [A: CatalystNumeric]: CatalystCast[A, Short] = of[A, Short] - implicit def framelessNumericToByte [A: CatalystNumeric]: CatalystCast[A, Byte] = of[A, Byte] - implicit def framelessNumericToDecimal[A: CatalystNumeric]: CatalystCast[A, BigDecimal] = of[A, BigDecimal] - implicit def framelessNumericToDouble [A: CatalystNumeric]: CatalystCast[A, Double] = of[A, Double] + implicit def framelessNumericToLong[ + A: CatalystNumeric + ]: CatalystCast[A, Long] = of[A, Long] + + implicit def framelessNumericToInt[A: CatalystNumeric]: CatalystCast[A, Int] = + of[A, Int] + + implicit def framelessNumericToShort[ + A: CatalystNumeric + ]: CatalystCast[A, Short] = of[A, Short] + + implicit def framelessNumericToByte[ + A: CatalystNumeric + ]: CatalystCast[A, Byte] = of[A, Byte] - implicit def framelessBooleanToNumeric[A: CatalystNumeric]: CatalystCast[Boolean, A] = of[Boolean, A] + implicit def framelessNumericToDecimal[ + A: CatalystNumeric + ]: CatalystCast[A, BigDecimal] = of[A, BigDecimal] + + implicit def framelessNumericToDouble[ + A: 
CatalystNumeric + ]: CatalystCast[A, Double] = of[A, Double] + + implicit def framelessBooleanToNumeric[ + A: CatalystNumeric + ]: CatalystCast[Boolean, A] = of[Boolean, A] // doesn't make any sense to include: // - sqlDateToBoolean: always None // - sqlTimestampToBoolean: compares us to 0 - implicit val framelessStringToBoolean : CatalystCast[String, Option[Boolean]] = of[String, Option[Boolean]] - implicit val framelessLongToBoolean : CatalystCast[Long, Boolean] = of[Long, Boolean] - implicit val framelessIntToBoolean : CatalystCast[Int, Boolean] = of[Int, Boolean] - implicit val framelessShortToBoolean : CatalystCast[Short, Boolean] = of[Short, Boolean] - implicit val framelessByteToBoolean : CatalystCast[Byte, Boolean] = of[Byte, Boolean] - implicit val framelessBigDecimalToBoolean: CatalystCast[BigDecimal, Boolean] = of[BigDecimal, Boolean] - implicit val framelessDoubleToBoolean : CatalystCast[Double, Boolean] = of[Double, Boolean] + implicit val framelessStringToBoolean: CatalystCast[String, Option[Boolean]] = + of[String, Option[Boolean]] + + implicit val framelessLongToBoolean: CatalystCast[Long, Boolean] = + of[Long, Boolean] + + implicit val framelessIntToBoolean: CatalystCast[Int, Boolean] = + of[Int, Boolean] + + implicit val framelessShortToBoolean: CatalystCast[Short, Boolean] = + of[Short, Boolean] + + implicit val framelessByteToBoolean: CatalystCast[Byte, Boolean] = + of[Byte, Boolean] + + implicit val framelessBigDecimalToBoolean: CatalystCast[BigDecimal, Boolean] = + of[BigDecimal, Boolean] + + implicit val framelessDoubleToBoolean: CatalystCast[Double, Boolean] = + of[Double, Boolean] // TODO @@ -38,9 +71,8 @@ object CatalystCast { // implicit object stringToLong extends CatalystCast[String, Option[Long]] // implicit object stringToSqlDate extends CatalystCast[String, Option[SQLDate]] - // needs verification: - //implicit object sqlTimestampToSqlDate extends CatalystCast[SQLTimestamp, SQLDate] + // implicit object sqlTimestampToSqlDate extends CatalystCast[SQLTimestamp, SQLDate] // needs verification: // implicit object sqlTimestampToDecimal extends CatalystCast[SQLTimestamp, BigDecimal] diff --git a/core/src/main/scala/frameless/CatalystCollection.scala b/core/src/main/scala/frameless/CatalystCollection.scala index 3456869a0..731a7385d 100644 --- a/core/src/main/scala/frameless/CatalystCollection.scala +++ b/core/src/main/scala/frameless/CatalystCollection.scala @@ -7,10 +7,12 @@ trait CatalystCollection[C[_]] object CatalystCollection { private[this] val theInstance = new CatalystCollection[Any] {} - private[this] def of[A[_]]: CatalystCollection[A] = theInstance.asInstanceOf[CatalystCollection[A]] - implicit val arrayObject : CatalystCollection[Array] = of[Array] - implicit val seqObject : CatalystCollection[Seq] = of[Seq] - implicit val listObject : CatalystCollection[List] = of[List] + private[this] def of[A[_]]: CatalystCollection[A] = + theInstance.asInstanceOf[CatalystCollection[A]] + + implicit val arrayObject: CatalystCollection[Array] = of[Array] + implicit val seqObject: CatalystCollection[Seq] = of[Seq] + implicit val listObject: CatalystCollection[List] = of[List] implicit val vectorObject: CatalystCollection[Vector] = of[Vector] } diff --git a/core/src/main/scala/frameless/CatalystDivisible.scala b/core/src/main/scala/frameless/CatalystDivisible.scala index c9080a5d8..01d849174 100644 --- a/core/src/main/scala/frameless/CatalystDivisible.scala +++ b/core/src/main/scala/frameless/CatalystDivisible.scala @@ -2,20 +2,34 @@ package frameless import 
scala.annotation.implicitNotFound -/** Spark divides everything as Double, expect BigDecimals are divided into - * another BigDecimal, benefiting from some added precision. - */ +/** + * Spark divides everything as Double, except BigDecimals are divided into + * another BigDecimal, benefiting from some added precision. + */ @implicitNotFound("Cannot compute division on type ${In}.") trait CatalystDivisible[In, Out] object CatalystDivisible { private[this] val theInstance = new CatalystDivisible[Any, Any] {} - private[this] def of[In, Out]: CatalystDivisible[In, Out] = theInstance.asInstanceOf[CatalystDivisible[In, Out]] - - implicit val framelessDivisibleBigDecimal: CatalystDivisible[BigDecimal, BigDecimal] = of[BigDecimal, BigDecimal] - implicit val framelessDivisibleDouble : CatalystDivisible[Double, Double] = of[Double, Double] - implicit val framelessDivisibleInt : CatalystDivisible[Int, Double] = of[Int, Double] - implicit val framelessDivisibleLong : CatalystDivisible[Long, Double] = of[Long, Double] - implicit val framelessDivisibleByte : CatalystDivisible[Byte, Double] = of[Byte, Double] - implicit val framelessDivisibleShort : CatalystDivisible[Short, Double] = of[Short, Double] + + private[this] def of[In, Out]: CatalystDivisible[In, Out] = + theInstance.asInstanceOf[CatalystDivisible[In, Out]] + + implicit val framelessDivisibleBigDecimal: CatalystDivisible[BigDecimal, BigDecimal] = + of[BigDecimal, BigDecimal] + + implicit val framelessDivisibleDouble: CatalystDivisible[Double, Double] = + of[Double, Double] + + implicit val framelessDivisibleInt: CatalystDivisible[Int, Double] = + of[Int, Double] + + implicit val framelessDivisibleLong: CatalystDivisible[Long, Double] = + of[Long, Double] + + implicit val framelessDivisibleByte: CatalystDivisible[Byte, Double] = + of[Byte, Double] + + implicit val framelessDivisibleShort: CatalystDivisible[Short, Double] = + of[Short, Double] } diff --git a/core/src/main/scala/frameless/CatalystIsin.scala b/core/src/main/scala/frameless/CatalystIsin.scala index f630a7155..fe12ab622 100644 --- a/core/src/main/scala/frameless/CatalystIsin.scala +++ b/core/src/main/scala/frameless/CatalystIsin.scala @@ -8,11 +8,11 @@ trait CatalystIsin[A] object CatalystIsin { implicit object framelessBigDecimal extends CatalystIsin[BigDecimal] - implicit object framelessByte extends CatalystIsin[Byte] - implicit object framelessDouble extends CatalystIsin[Double] - implicit object framelessFloat extends CatalystIsin[Float] - implicit object framelessInt extends CatalystIsin[Int] - implicit object framelessLong extends CatalystIsin[Long] - implicit object framelessShort extends CatalystIsin[Short] - implicit object framelesssString extends CatalystIsin[String] + implicit object framelessByte extends CatalystIsin[Byte] + implicit object framelessDouble extends CatalystIsin[Double] + implicit object framelessFloat extends CatalystIsin[Float] + implicit object framelessInt extends CatalystIsin[Int] + implicit object framelessLong extends CatalystIsin[Long] + implicit object framelessShort extends CatalystIsin[Short] + implicit object framelesssString extends CatalystIsin[String] } diff --git a/core/src/main/scala/frameless/CatalystNaN.scala b/core/src/main/scala/frameless/CatalystNaN.scala index 3e7be8263..56549ed24 100644 --- a/core/src/main/scala/frameless/CatalystNaN.scala +++ b/core/src/main/scala/frameless/CatalystNaN.scala @@ -8,9 +8,10 @@ trait CatalystNaN[A] object CatalystNaN { private[this] val theInstance = new CatalystNaN[Any] {} - private[this] def
of[A]: CatalystNaN[A] = theInstance.asInstanceOf[CatalystNaN[A]] - implicit val framelessFloatNaN : CatalystNaN[Float] = of[Float] - implicit val framelessDoubleNaN : CatalystNaN[Double] = of[Double] -} + private[this] def of[A]: CatalystNaN[A] = + theInstance.asInstanceOf[CatalystNaN[A]] + implicit val framelessFloatNaN: CatalystNaN[Float] = of[Float] + implicit val framelessDoubleNaN: CatalystNaN[Double] = of[Double] +} diff --git a/core/src/main/scala/frameless/CatalystNotNullable.scala b/core/src/main/scala/frameless/CatalystNotNullable.scala index e8d4b3be1..58636455d 100644 --- a/core/src/main/scala/frameless/CatalystNotNullable.scala +++ b/core/src/main/scala/frameless/CatalystNotNullable.scala @@ -6,13 +6,20 @@ import scala.annotation.implicitNotFound trait CatalystNullable[A] object CatalystNullable { - implicit def optionIsNullable[A]: CatalystNullable[Option[A]] = new CatalystNullable[Option[A]] {} + + implicit def optionIsNullable[A]: CatalystNullable[Option[A]] = + new CatalystNullable[Option[A]] {} } @implicitNotFound("Cannot find evidence that type ${A} is not nullable.") trait NotCatalystNullable[A] object NotCatalystNullable { - implicit def everythingIsNotNullable[A]: NotCatalystNullable[A] = new NotCatalystNullable[A] {} - implicit def nullableIsNotNotNullable[A: CatalystNullable]: NotCatalystNullable[A] = new NotCatalystNullable[A] {} + + implicit def everythingIsNotNullable[A]: NotCatalystNullable[A] = + new NotCatalystNullable[A] {} + + implicit def nullableIsNotNotNullable[ + A: CatalystNullable + ]: NotCatalystNullable[A] = new NotCatalystNullable[A] {} } diff --git a/core/src/main/scala/frameless/CatalystNumeric.scala b/core/src/main/scala/frameless/CatalystNumeric.scala index c819ba2ae..fd3aa027e 100644 --- a/core/src/main/scala/frameless/CatalystNumeric.scala +++ b/core/src/main/scala/frameless/CatalystNumeric.scala @@ -8,12 +8,15 @@ trait CatalystNumeric[A] object CatalystNumeric { private[this] val theInstance = new CatalystNumeric[Any] {} - private[this] def of[A]: CatalystNumeric[A] = theInstance.asInstanceOf[CatalystNumeric[A]] - implicit val framelessbigDecimalNumeric: CatalystNumeric[BigDecimal] = of[BigDecimal] - implicit val framelessbyteNumeric : CatalystNumeric[Byte] = of[Byte] - implicit val framelessdoubleNumeric : CatalystNumeric[Double] = of[Double] - implicit val framelessintNumeric : CatalystNumeric[Int] = of[Int] - implicit val framelesslongNumeric : CatalystNumeric[Long] = of[Long] - implicit val framelessshortNumeric : CatalystNumeric[Short] = of[Short] + private[this] def of[A]: CatalystNumeric[A] = + theInstance.asInstanceOf[CatalystNumeric[A]] + + implicit val framelessbigDecimalNumeric: CatalystNumeric[BigDecimal] = + of[BigDecimal] + implicit val framelessbyteNumeric: CatalystNumeric[Byte] = of[Byte] + implicit val framelessdoubleNumeric: CatalystNumeric[Double] = of[Double] + implicit val framelessintNumeric: CatalystNumeric[Int] = of[Int] + implicit val framelesslongNumeric: CatalystNumeric[Long] = of[Long] + implicit val framelessshortNumeric: CatalystNumeric[Short] = of[Short] } diff --git a/core/src/main/scala/frameless/CatalystNumericWithJavaBigDecimal.scala b/core/src/main/scala/frameless/CatalystNumericWithJavaBigDecimal.scala index 8fee63be2..08f69def0 100644 --- a/core/src/main/scala/frameless/CatalystNumericWithJavaBigDecimal.scala +++ b/core/src/main/scala/frameless/CatalystNumericWithJavaBigDecimal.scala @@ -2,20 +2,36 @@ package frameless import scala.annotation.implicitNotFound -/** Spark does not return always the same 
type as the input was for example abs - */ +/** + * Spark does not always return the same type as the input, for example abs + */ @implicitNotFound("Cannot compute on type ${In}.") trait CatalystNumericWithJavaBigDecimal[In, Out] object CatalystNumericWithJavaBigDecimal { - private[this] val theInstance = new CatalystNumericWithJavaBigDecimal[Any, Any] {} - private[this] def of[In, Out]: CatalystNumericWithJavaBigDecimal[In, Out] = theInstance.asInstanceOf[CatalystNumericWithJavaBigDecimal[In, Out]] - implicit val framelessAbsoluteBigDecimal: CatalystNumericWithJavaBigDecimal[BigDecimal, java.math.BigDecimal] = of[BigDecimal, java.math.BigDecimal] - implicit val framelessAbsoluteDouble : CatalystNumericWithJavaBigDecimal[Double, Double] = of[Double, Double] - implicit val framelessAbsoluteInt : CatalystNumericWithJavaBigDecimal[Int, Int] = of[Int, Int] - implicit val framelessAbsoluteLong : CatalystNumericWithJavaBigDecimal[Long, Long] = of[Long, Long] - implicit val framelessAbsoluteShort : CatalystNumericWithJavaBigDecimal[Short, Short] = of[Short, Short] - implicit val framelessAbsoluteByte : CatalystNumericWithJavaBigDecimal[Byte, Byte] = of[Byte, Byte] + private[this] val theInstance = + new CatalystNumericWithJavaBigDecimal[Any, Any] {} -} \ No newline at end of file + private[this] def of[In, Out]: CatalystNumericWithJavaBigDecimal[In, Out] = + theInstance.asInstanceOf[CatalystNumericWithJavaBigDecimal[In, Out]] + + implicit val framelessAbsoluteBigDecimal: CatalystNumericWithJavaBigDecimal[BigDecimal, java.math.BigDecimal] = + of[BigDecimal, java.math.BigDecimal] + + implicit val framelessAbsoluteDouble: CatalystNumericWithJavaBigDecimal[Double, Double] = + of[Double, Double] + + implicit val framelessAbsoluteInt: CatalystNumericWithJavaBigDecimal[Int, Int] = + of[Int, Int] + + implicit val framelessAbsoluteLong: CatalystNumericWithJavaBigDecimal[Long, Long] = + of[Long, Long] + + implicit val framelessAbsoluteShort: CatalystNumericWithJavaBigDecimal[Short, Short] = + of[Short, Short] + + implicit val framelessAbsoluteByte: CatalystNumericWithJavaBigDecimal[Byte, Byte] = + of[Byte, Byte] + +} diff --git a/core/src/main/scala/frameless/CatalystOrdered.scala b/core/src/main/scala/frameless/CatalystOrdered.scala index e73604909..119ba248e 100644 --- a/core/src/main/scala/frameless/CatalystOrdered.scala +++ b/core/src/main/scala/frameless/CatalystOrdered.scala @@ -1,9 +1,9 @@ package frameless import scala.annotation.implicitNotFound -import shapeless.{Generic, HList, Lazy} +import shapeless.{ Generic, HList, Lazy } import shapeless.ops.hlist.LiftAll -import java.time.{Duration, Instant, Period} +import java.time.{ Duration, Instant, Period } /** Types that can be ordered/compared by Catalyst. 
*/ @implicitNotFound("Cannot compare columns of type ${A}.") @@ -11,31 +11,39 @@ trait CatalystOrdered[A] object CatalystOrdered { private[this] val theInstance = new CatalystOrdered[Any] {} - private[this] def of[A]: CatalystOrdered[A] = theInstance.asInstanceOf[CatalystOrdered[A]] - - implicit val framelessIntOrdered : CatalystOrdered[Int] = of[Int] - implicit val framelessBooleanOrdered : CatalystOrdered[Boolean] = of[Boolean] - implicit val framelessByteOrdered : CatalystOrdered[Byte] = of[Byte] - implicit val framelessShortOrdered : CatalystOrdered[Short] = of[Short] - implicit val framelessLongOrdered : CatalystOrdered[Long] = of[Long] - implicit val framelessFloatOrdered : CatalystOrdered[Float] = of[Float] - implicit val framelessDoubleOrdered : CatalystOrdered[Double] = of[Double] - implicit val framelessBigDecimalOrdered : CatalystOrdered[BigDecimal] = of[BigDecimal] - implicit val framelessSQLDateOrdered : CatalystOrdered[SQLDate] = of[SQLDate] - implicit val framelessSQLTimestampOrdered: CatalystOrdered[SQLTimestamp] = of[SQLTimestamp] - implicit val framelessStringOrdered : CatalystOrdered[String] = of[String] - implicit val framelessInstantOrdered : CatalystOrdered[Instant] = of[Instant] - implicit val framelessDurationOrdered : CatalystOrdered[Duration] = of[Duration] - implicit val framelessPeriodOrdered : CatalystOrdered[Period] = of[Period] - - implicit def injectionOrdered[A, B] - (implicit + + private[this] def of[A]: CatalystOrdered[A] = + theInstance.asInstanceOf[CatalystOrdered[A]] + + implicit val framelessIntOrdered: CatalystOrdered[Int] = of[Int] + implicit val framelessBooleanOrdered: CatalystOrdered[Boolean] = of[Boolean] + implicit val framelessByteOrdered: CatalystOrdered[Byte] = of[Byte] + implicit val framelessShortOrdered: CatalystOrdered[Short] = of[Short] + implicit val framelessLongOrdered: CatalystOrdered[Long] = of[Long] + implicit val framelessFloatOrdered: CatalystOrdered[Float] = of[Float] + implicit val framelessDoubleOrdered: CatalystOrdered[Double] = of[Double] + + implicit val framelessBigDecimalOrdered: CatalystOrdered[BigDecimal] = + of[BigDecimal] + implicit val framelessSQLDateOrdered: CatalystOrdered[SQLDate] = of[SQLDate] + + implicit val framelessSQLTimestampOrdered: CatalystOrdered[SQLTimestamp] = + of[SQLTimestamp] + implicit val framelessStringOrdered: CatalystOrdered[String] = of[String] + implicit val framelessInstantOrdered: CatalystOrdered[Instant] = of[Instant] + + implicit val framelessDurationOrdered: CatalystOrdered[Duration] = + of[Duration] + implicit val framelessPeriodOrdered: CatalystOrdered[Period] = of[Period] + + implicit def injectionOrdered[A, B]( + implicit i0: Injection[A, B], i1: CatalystOrdered[B] ): CatalystOrdered[A] = of[A] - implicit def deriveGeneric[G, H <: HList] - (implicit + implicit def deriveGeneric[G, H <: HList]( + implicit i0: Generic.Aux[G, H], i1: Lazy[LiftAll[CatalystOrdered, H]] ): CatalystOrdered[G] = of[G] diff --git a/core/src/main/scala/frameless/CatalystPivotable.scala b/core/src/main/scala/frameless/CatalystPivotable.scala index a7b34da64..091d82818 100644 --- a/core/src/main/scala/frameless/CatalystPivotable.scala +++ b/core/src/main/scala/frameless/CatalystPivotable.scala @@ -7,10 +7,14 @@ trait CatalystPivotable[A] object CatalystPivotable { private[this] val theInstance = new CatalystPivotable[Any] {} - private[this] def of[A]: CatalystPivotable[A] = theInstance.asInstanceOf[CatalystPivotable[A]] - implicit val framelessIntPivotable : CatalystPivotable[Int] = of[Int] - implicit val 
framelessLongPivotable : CatalystPivotable[Long] = of[Long] - implicit val framelessBooleanPivotable: CatalystPivotable[Boolean] = of[Boolean] - implicit val framelessStringPivotable : CatalystPivotable[String] = of[String] + private[this] def of[A]: CatalystPivotable[A] = + theInstance.asInstanceOf[CatalystPivotable[A]] + + implicit val framelessIntPivotable: CatalystPivotable[Int] = of[Int] + implicit val framelessLongPivotable: CatalystPivotable[Long] = of[Long] + + implicit val framelessBooleanPivotable: CatalystPivotable[Boolean] = + of[Boolean] + implicit val framelessStringPivotable: CatalystPivotable[String] = of[String] } diff --git a/core/src/main/scala/frameless/CatalystRound.scala b/core/src/main/scala/frameless/CatalystRound.scala index ee50b794a..92b5c1c54 100644 --- a/core/src/main/scala/frameless/CatalystRound.scala +++ b/core/src/main/scala/frameless/CatalystRound.scala @@ -2,18 +2,22 @@ package frameless import scala.annotation.implicitNotFound -/** Spark does not return always long on round - */ +/** + * Spark does not always return Long on round + */ @implicitNotFound("Cannot compute round on type ${In}.") trait CatalystRound[In, Out] object CatalystRound { private[this] val theInstance = new CatalystRound[Any, Any] {} - private[this] def of[In, Out]: CatalystRound[In, Out] = theInstance.asInstanceOf[CatalystRound[In, Out]] - implicit val framelessBigDecimal: CatalystRound[BigDecimal, java.math.BigDecimal] = of[BigDecimal, java.math.BigDecimal] - implicit val framelessDouble : CatalystRound[Double, Long] = of[Double, Long] - implicit val framelessInt : CatalystRound[Int, Long] = of[Int, Long] - implicit val framelessLong : CatalystRound[Long, Long] = of[Long, Long] - implicit val framelessShort : CatalystRound[Short, Long] = of[Short, Long] -} \ No newline at end of file + private[this] def of[In, Out]: CatalystRound[In, Out] = + theInstance.asInstanceOf[CatalystRound[In, Out]] + + implicit val framelessBigDecimal: CatalystRound[BigDecimal, java.math.BigDecimal] = + of[BigDecimal, java.math.BigDecimal] + implicit val framelessDouble: CatalystRound[Double, Long] = of[Double, Long] + implicit val framelessInt: CatalystRound[Int, Long] = of[Int, Long] + implicit val framelessLong: CatalystRound[Long, Long] = of[Long, Long] + implicit val framelessShort: CatalystRound[Short, Long] = of[Short, Long] +} diff --git a/core/src/main/scala/frameless/CatalystSummable.scala b/core/src/main/scala/frameless/CatalystSummable.scala index 94010505e..8de1bb6f3 100644 --- a/core/src/main/scala/frameless/CatalystSummable.scala +++ b/core/src/main/scala/frameless/CatalystSummable.scala @@ -3,29 +3,39 @@ package frameless import scala.annotation.implicitNotFound /** - * When summing Spark doesn't change these types: - * - Long -> Long - * - BigDecimal -> BigDecimal - * - Double -> Double - * - * For other types there are conversions: - * - Int -> Long - * - Short -> Long - */ + * When summing Spark doesn't change these types: + * - Long -> Long + * - BigDecimal -> BigDecimal + * - Double -> Double + * + * For other types there are conversions: + * - Int -> Long + * - Short -> Long + */ @implicitNotFound("Cannot compute sum of type ${In}.") trait CatalystSummable[In, Out] { def zero: In } object CatalystSummable { + def apply[In, Out](zero: In): CatalystSummable[In, Out] = { val _zero = zero new CatalystSummable[In, Out] { val zero: In = _zero } } - implicit val framelessSummableLong : CatalystSummable[Long, Long] = CatalystSummable(zero = 0L) - implicit val framelessSummableBigDecimal: 
CatalystSummable[BigDecimal, BigDecimal] = CatalystSummable(zero = BigDecimal(0)) - implicit val framelessSummableDouble : CatalystSummable[Double, Double] = CatalystSummable(zero = 0.0) - implicit val framelessSummableInt : CatalystSummable[Int, Long] = CatalystSummable(zero = 0) - implicit val framelessSummableShort : CatalystSummable[Short, Long] = CatalystSummable(zero = 0) + implicit val framelessSummableLong: CatalystSummable[Long, Long] = + CatalystSummable(zero = 0L) + + implicit val framelessSummableBigDecimal: CatalystSummable[BigDecimal, BigDecimal] = + CatalystSummable(zero = BigDecimal(0)) + + implicit val framelessSummableDouble: CatalystSummable[Double, Double] = + CatalystSummable(zero = 0.0) + + implicit val framelessSummableInt: CatalystSummable[Int, Long] = + CatalystSummable(zero = 0) + + implicit val framelessSummableShort: CatalystSummable[Short, Long] = + CatalystSummable(zero = 0) } diff --git a/core/src/main/scala/frameless/CatalystVariance.scala b/core/src/main/scala/frameless/CatalystVariance.scala index 9e843fa70..e9f935549 100644 --- a/core/src/main/scala/frameless/CatalystVariance.scala +++ b/core/src/main/scala/frameless/CatalystVariance.scala @@ -3,18 +3,22 @@ package frameless import scala.annotation.implicitNotFound /** - * Spark's variance and stddev functions always return Double - */ + * Spark's variance and stddev functions always return Double + */ @implicitNotFound("Cannot compute variance on type ${A}.") trait CatalystVariance[A] object CatalystVariance { private[this] val theInstance = new CatalystVariance[Any] {} - private[this] def of[A]: CatalystVariance[A] = theInstance.asInstanceOf[CatalystVariance[A]] - implicit val framelessIntVariance : CatalystVariance[Int] = of[Int] - implicit val framelessLongVariance : CatalystVariance[Long] = of[Long] - implicit val framelessShortVariance : CatalystVariance[Short] = of[Short] - implicit val framelessBigDecimalVariance: CatalystVariance[BigDecimal] = of[BigDecimal] - implicit val framelessDoubleVariance : CatalystVariance[Double] = of[Double] + private[this] def of[A]: CatalystVariance[A] = + theInstance.asInstanceOf[CatalystVariance[A]] + + implicit val framelessIntVariance: CatalystVariance[Int] = of[Int] + implicit val framelessLongVariance: CatalystVariance[Long] = of[Long] + implicit val framelessShortVariance: CatalystVariance[Short] = of[Short] + + implicit val framelessBigDecimalVariance: CatalystVariance[BigDecimal] = + of[BigDecimal] + implicit val framelessDoubleVariance: CatalystVariance[Double] = of[Double] } diff --git a/dataset/src/main/scala/frameless/FramelessSyntax.scala b/dataset/src/main/scala/frameless/FramelessSyntax.scala index 5ba294921..3967314df 100644 --- a/dataset/src/main/scala/frameless/FramelessSyntax.scala +++ b/dataset/src/main/scala/frameless/FramelessSyntax.scala @@ -1,18 +1,25 @@ package frameless -import org.apache.spark.sql.{Column, DataFrame, Dataset} +import org.apache.spark.sql.{ Column, DataFrame, Dataset } trait FramelessSyntax { + implicit class ColumnSyntax(self: Column) { - def typedColumn[T, U: TypedEncoder]: TypedColumn[T, U] = new TypedColumn[T, U](self) - def typedAggregate[T, U: TypedEncoder]: TypedAggregate[T, U] = new TypedAggregate[T, U](self) + + def typedColumn[T, U: TypedEncoder]: TypedColumn[T, U] = + new TypedColumn[T, U](self) + + def typedAggregate[T, U: TypedEncoder]: TypedAggregate[T, U] = + new TypedAggregate[T, U](self) } implicit class DatasetSyntax[T: TypedEncoder](self: Dataset[T]) { def typed: TypedDataset[T] = 
TypedDataset.create[T](self) } - implicit class DataframeSyntax(self: DataFrame){ - def unsafeTyped[T: TypedEncoder]: TypedDataset[T] = TypedDataset.createUnsafe(self) + implicit class DataframeSyntax(self: DataFrame) { + + def unsafeTyped[T: TypedEncoder]: TypedDataset[T] = + TypedDataset.createUnsafe(self) } } diff --git a/dataset/src/main/scala/frameless/InjectionEnum.scala b/dataset/src/main/scala/frameless/InjectionEnum.scala index 4ed1006e3..677b321ef 100644 --- a/dataset/src/main/scala/frameless/InjectionEnum.scala +++ b/dataset/src/main/scala/frameless/InjectionEnum.scala @@ -3,6 +3,7 @@ package frameless import shapeless._ trait InjectionEnum { + implicit val cnilInjectionEnum: Injection[CNil, String] = Injection( // $COVERAGE-OFF$No value of type CNil so impossible to test @@ -15,10 +16,10 @@ trait InjectionEnum { ) implicit def coproductInjectionEnum[H, T <: Coproduct]( - implicit - typeable: Typeable[H] , - gen: Generic.Aux[H, HNil], - tInjectionEnum: Injection[T, String] + implicit + typeable: Typeable[H], + gen: Generic.Aux[H, HNil], + tInjectionEnum: Injection[T, String] ): Injection[H :+: T, String] = { val dataConstructorName = typeable.describe.takeWhile(_ != '.') @@ -37,9 +38,9 @@ trait InjectionEnum { } implicit def genericInjectionEnum[A, R]( - implicit - gen: Generic.Aux[A, R], - rInjectionEnum: Injection[R, String] + implicit + gen: Generic.Aux[A, R], + rInjectionEnum: Injection[R, String] ): Injection[A, String] = Injection( value => rInjectionEnum(gen.to(value)), diff --git a/dataset/src/main/scala/frameless/IsValueClass.scala b/dataset/src/main/scala/frameless/IsValueClass.scala index 78605c130..c8097e785 100644 --- a/dataset/src/main/scala/frameless/IsValueClass.scala +++ b/dataset/src/main/scala/frameless/IsValueClass.scala @@ -5,13 +5,18 @@ import shapeless.labelled.FieldType /** Evidence that `T` is a Value class */ @annotation.implicitNotFound(msg = "${T} is not a Value class") -final class IsValueClass[T] private() {} +final class IsValueClass[T] private () {} object IsValueClass { + /** Provides an evidence `A` is a Value class */ - implicit def apply[A <: AnyVal, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil]]( - implicit + implicit def apply[ + A <: AnyVal, + G <: ::[_, HNil], + H <: ::[_ <: FieldType[_ <: Symbol, _], HNil] + ](implicit i0: LabelledGeneric.Aux[A, G], - i1: DropUnitValues.Aux[G, H]): IsValueClass[A] = new IsValueClass[A] + i1: DropUnitValues.Aux[G, H] + ): IsValueClass[A] = new IsValueClass[A] } diff --git a/dataset/src/main/scala/frameless/Job.scala b/dataset/src/main/scala/frameless/Job.scala index 40931b8b4..ead6ecf09 100644 --- a/dataset/src/main/scala/frameless/Job.scala +++ b/dataset/src/main/scala/frameless/Job.scala @@ -2,7 +2,10 @@ package frameless import org.apache.spark.sql.SparkSession -sealed abstract class Job[A](implicit spark: SparkSession) { self => +sealed abstract class Job[A]( + implicit + spark: SparkSession) { self => + /** Runs a new Spark job. 
*/ def run(): A @@ -32,13 +35,22 @@ sealed abstract class Job[A](implicit spark: SparkSession) { self => } } - object Job { - def apply[A](a: => A)(implicit spark: SparkSession): Job[A] = new Job[A] { + + def apply[A]( + a: => A + )(implicit + spark: SparkSession + ): Job[A] = new Job[A] { def run(): A = a } - implicit val framelessSparkDelayForJob: SparkDelay[Job] = new SparkDelay[Job] { - def delay[A](a: => A)(implicit spark: SparkSession): Job[A] = Job(a) - } + implicit val framelessSparkDelayForJob: SparkDelay[Job] = + new SparkDelay[Job] { + def delay[A]( + a: => A + )(implicit + spark: SparkSession + ): Job[A] = Job(a) + } } diff --git a/dataset/src/main/scala/frameless/RecordEncoder.scala b/dataset/src/main/scala/frameless/RecordEncoder.scala index 7427d9de0..269a4879a 100644 --- a/dataset/src/main/scala/frameless/RecordEncoder.scala +++ b/dataset/src/main/scala/frameless/RecordEncoder.scala @@ -4,7 +4,10 @@ import org.apache.spark.sql.FramelessInternals import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.{ - Invoke, NewInstance, UnwrapOption, WrapOption + Invoke, + NewInstance, + UnwrapOption, + WrapOption } import org.apache.spark.sql.types._ @@ -16,10 +19,9 @@ import shapeless.ops.record.Keys import scala.reflect.ClassTag case class RecordEncoderField( - ordinal: Int, - name: String, - encoder: TypedEncoder[_] -) + ordinal: Int, + name: String, + encoder: TypedEncoder[_]) trait RecordEncoderFields[T <: HList] extends Serializable { def value: List[RecordEncoderField] @@ -30,33 +32,40 @@ trait RecordEncoderFields[T <: HList] extends Serializable { object RecordEncoderFields { - implicit def deriveRecordLast[K <: Symbol, H] - (implicit + implicit def deriveRecordLast[K <: Symbol, H]( + implicit key: Witness.Aux[K], head: RecordFieldEncoder[H] - ): RecordEncoderFields[FieldType[K, H] :: HNil] = new RecordEncoderFields[FieldType[K, H] :: HNil] { + ): RecordEncoderFields[FieldType[K, H] :: HNil] = + new RecordEncoderFields[FieldType[K, H] :: HNil] { def value: List[RecordEncoderField] = fieldEncoder[K, H] :: Nil } - implicit def deriveRecordCons[K <: Symbol, H, T <: HList] - (implicit + implicit def deriveRecordCons[K <: Symbol, H, T <: HList]( + implicit key: Witness.Aux[K], head: RecordFieldEncoder[H], tail: RecordEncoderFields[T] - ): RecordEncoderFields[FieldType[K, H] :: T] = new RecordEncoderFields[FieldType[K, H] :: T] { + ): RecordEncoderFields[FieldType[K, H] :: T] = + new RecordEncoderFields[FieldType[K, H] :: T] { def value: List[RecordEncoderField] = - fieldEncoder[K, H] :: tail.value.map(x => x.copy(ordinal = x.ordinal + 1)) - } + fieldEncoder[K, H] :: tail.value + .map(x => x.copy(ordinal = x.ordinal + 1)) + } - private def fieldEncoder[K <: Symbol, H](implicit key: Witness.Aux[K], e: RecordFieldEncoder[H]): RecordEncoderField = RecordEncoderField(0, key.value.name, e.encoder) + private def fieldEncoder[K <: Symbol, H]( + implicit + key: Witness.Aux[K], + e: RecordFieldEncoder[H] + ): RecordEncoderField = RecordEncoderField(0, key.value.name, e.encoder) } /** - * Assists the generation of constructor call parameters from a labelled generic representation. - * As Unit typed fields were removed earlier, we need to put back unit literals in the appropriate positions. - * - * @tparam T labelled generic representation of type fields - */ + * Assists the generation of constructor call parameters from a labelled generic representation. 
+ * As Unit typed fields were removed earlier, we need to put back unit literals in the appropriate positions. + * + * @tparam T labelled generic representation of type fields + */ trait NewInstanceExprs[T <: HList] extends Serializable { def from(exprs: List[Expression]): Seq[Expression] } @@ -67,32 +76,41 @@ object NewInstanceExprs { def from(exprs: List[Expression]): Seq[Expression] = Nil } - implicit def deriveUnit[K <: Symbol, T <: HList] - (implicit + implicit def deriveUnit[K <: Symbol, T <: HList]( + implicit tail: NewInstanceExprs[T] - ): NewInstanceExprs[FieldType[K, Unit] :: T] = new NewInstanceExprs[FieldType[K, Unit] :: T] { + ): NewInstanceExprs[FieldType[K, Unit] :: T] = + new NewInstanceExprs[FieldType[K, Unit] :: T] { def from(exprs: List[Expression]): Seq[Expression] = Literal.fromObject(()) +: tail.from(exprs) } - implicit def deriveNonUnit[K <: Symbol, V, T <: HList] - (implicit + implicit def deriveNonUnit[K <: Symbol, V, T <: HList]( + implicit notUnit: V =:!= Unit, tail: NewInstanceExprs[T] - ): NewInstanceExprs[FieldType[K, V] :: T] = new NewInstanceExprs[FieldType[K, V] :: T] { - def from(exprs: List[Expression]): Seq[Expression] = exprs.head +: tail.from(exprs.tail) + ): NewInstanceExprs[FieldType[K, V] :: T] = + new NewInstanceExprs[FieldType[K, V] :: T] { + def from(exprs: List[Expression]): Seq[Expression] = + exprs.head +: tail.from(exprs.tail) } } /** - * Drops fields with Unit type from labelled generic representation of types. - * - * @tparam L labelled generic representation of type fields - */ -trait DropUnitValues[L <: HList] extends DepFn1[L] with Serializable { type Out <: HList } + * Drops fields with Unit type from labelled generic representation of types. + * + * @tparam L labelled generic representation of type fields + */ +trait DropUnitValues[L <: HList] extends DepFn1[L] with Serializable { + type Out <: HList +} object DropUnitValues { - def apply[L <: HList](implicit dropUnitValues: DropUnitValues[L]): Aux[L, dropUnitValues.Out] = dropUnitValues + + def apply[L <: HList]( + implicit + dropUnitValues: DropUnitValues[L] + ): Aux[L, dropUnitValues.Out] = dropUnitValues type Aux[L <: HList, Out0 <: HList] = DropUnitValues[L] { type Out = Out0 } @@ -101,93 +119,94 @@ object DropUnitValues { def apply(l: HNil): Out = HNil } - implicit def deriveUnit[K <: Symbol, T <: HList, OutT <: HList] - (implicit - dropUnitValues : DropUnitValues.Aux[T, OutT] - ): Aux[FieldType[K, Unit] :: T, OutT] = new DropUnitValues[FieldType[K, Unit] :: T] { + implicit def deriveUnit[K <: Symbol, T <: HList, OutT <: HList]( + implicit + dropUnitValues: DropUnitValues.Aux[T, OutT] + ): Aux[FieldType[K, Unit] :: T, OutT] = + new DropUnitValues[FieldType[K, Unit] :: T] { type Out = OutT - def apply(l : FieldType[K, Unit] :: T): Out = dropUnitValues(l.tail) + def apply(l: FieldType[K, Unit] :: T): Out = dropUnitValues(l.tail) } - implicit def deriveNonUnit[K <: Symbol, V, T <: HList, OutH, OutT <: HList] - (implicit + implicit def deriveNonUnit[K <: Symbol, V, T <: HList, OutH, OutT <: HList]( + implicit nonUnit: V =:!= Unit, - dropUnitValues : DropUnitValues.Aux[T, OutT] - ): Aux[FieldType[K, V] :: T, FieldType[K, V] :: OutT] = new DropUnitValues[FieldType[K, V] :: T] { + dropUnitValues: DropUnitValues.Aux[T, OutT] + ): Aux[FieldType[K, V] :: T, FieldType[K, V] :: OutT] = + new DropUnitValues[FieldType[K, V] :: T] { type Out = FieldType[K, V] :: OutT - def apply(l : FieldType[K, V] :: T): Out = l.head :: dropUnitValues(l.tail) + def apply(l: FieldType[K, V] :: T): Out = l.head 
:: dropUnitValues(l.tail) } } -class RecordEncoder[F, G <: HList, H <: HList] - (implicit +class RecordEncoder[F, G <: HList, H <: HList]( + implicit i0: LabelledGeneric.Aux[F, G], i1: DropUnitValues.Aux[G, H], i2: IsHCons[H], fields: Lazy[RecordEncoderFields[H]], newInstanceExprs: Lazy[NewInstanceExprs[G]], - classTag: ClassTag[F] - ) extends TypedEncoder[F] { - def nullable: Boolean = false - - def jvmRepr: DataType = FramelessInternals.objectTypeFor[F] - - def catalystRepr: DataType = { - val structFields = fields.value.value.map { field => - StructField( - name = field.name, - dataType = field.encoder.catalystRepr, - nullable = field.encoder.nullable, - metadata = Metadata.empty - ) - } - - StructType(structFields) + classTag: ClassTag[F]) + extends TypedEncoder[F] { + def nullable: Boolean = false + + def jvmRepr: DataType = FramelessInternals.objectTypeFor[F] + + def catalystRepr: DataType = { + val structFields = fields.value.value.map { field => + StructField( + name = field.name, + dataType = field.encoder.catalystRepr, + nullable = field.encoder.nullable, + metadata = Metadata.empty + ) } - def toCatalyst(path: Expression): Expression = { - val nameExprs = fields.value.value.map { field => - Literal(field.name) - } - - val valueExprs = fields.value.value.map { field => - val fieldPath = Invoke(path, field.name, field.encoder.jvmRepr, Nil) - field.encoder.toCatalyst(fieldPath) - } - - // the way exprs are encoded in CreateNamedStruct - val exprs = nameExprs.zip(valueExprs).flatMap { - case (nameExpr, valueExpr) => nameExpr :: valueExpr :: Nil - } + StructType(structFields) + } - val createExpr = CreateNamedStruct(exprs) - val nullExpr = Literal.create(null, createExpr.dataType) + def toCatalyst(path: Expression): Expression = { + val nameExprs = fields.value.value.map { field => Literal(field.name) } - If(IsNull(path), nullExpr, createExpr) + val valueExprs = fields.value.value.map { field => + val fieldPath = Invoke(path, field.name, field.encoder.jvmRepr, Nil) + field.encoder.toCatalyst(fieldPath) } - def fromCatalyst(path: Expression): Expression = { - val exprs = fields.value.value.map { field => - field.encoder.fromCatalyst( - GetStructField(path, field.ordinal, Some(field.name))) - } + // the way exprs are encoded in CreateNamedStruct + val exprs = nameExprs.zip(valueExprs).flatMap { + case (nameExpr, valueExpr) => nameExpr :: valueExpr :: Nil + } - val newArgs = newInstanceExprs.value.from(exprs) - val newExpr = NewInstance( - classTag.runtimeClass, newArgs, jvmRepr, propagateNull = true) + val createExpr = CreateNamedStruct(exprs) + val nullExpr = Literal.create(null, createExpr.dataType) - val nullExpr = Literal.create(null, jvmRepr) + If(IsNull(path), nullExpr, createExpr) + } - If(IsNull(path), nullExpr, newExpr) + def fromCatalyst(path: Expression): Expression = { + val exprs = fields.value.value.map { field => + field.encoder.fromCatalyst( + GetStructField(path, field.ordinal, Some(field.name)) + ) } + + val newArgs = newInstanceExprs.value.from(exprs) + val newExpr = + NewInstance(classTag.runtimeClass, newArgs, jvmRepr, propagateNull = true) + + val nullExpr = Literal.create(null, jvmRepr) + + If(IsNull(path), nullExpr, newExpr) + } } final class RecordFieldEncoder[T]( - val encoder: TypedEncoder[T], - private[frameless] val jvmRepr: DataType, - private[frameless] val fromCatalyst: Expression => Expression, - private[frameless] val toCatalyst: Expression => Expression -) extends Serializable + val encoder: TypedEncoder[T], + private[frameless] val jvmRepr: DataType, 
+ private[frameless] val fromCatalyst: Expression => Expression, + private[frameless] val toCatalyst: Expression => Expression) + extends Serializable object RecordFieldEncoder extends RecordFieldEncoderLowPriority { @@ -198,8 +217,14 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority { * @tparam K the key type for the fields * @tparam V the inner value type */ - implicit def optionValueClass[F : IsValueClass, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], K <: Symbol, V, KS <: ::[_ <: Symbol, HNil]] - (implicit + implicit def optionValueClass[ + F: IsValueClass, + G <: ::[_, HNil], + H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], + K <: Symbol, + V, + KS <: ::[_ <: Symbol, HNil] + ](implicit i0: LabelledGeneric.Aux[F, G], i1: DropUnitValues.Aux[G, H], i2: IsHCons.Aux[H, _ <: FieldType[K, V], HNil], @@ -208,49 +233,49 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority { i5: TypedEncoder[V], i6: ClassTag[F] ): RecordFieldEncoder[Option[F]] = { - val fieldName = i4.head(i3()).name - val innerJvmRepr = ObjectType(i6.runtimeClass) + val fieldName = i4.head(i3()).name + val innerJvmRepr = ObjectType(i6.runtimeClass) - val catalyst: Expression => Expression = { path => - val value = UnwrapOption(innerJvmRepr, path) - val javaValue = Invoke(value, fieldName, i5.jvmRepr, Nil) + val catalyst: Expression => Expression = { path => + val value = UnwrapOption(innerJvmRepr, path) + val javaValue = Invoke(value, fieldName, i5.jvmRepr, Nil) - i5.toCatalyst(javaValue) - } + i5.toCatalyst(javaValue) + } - val fromCatalyst: Expression => Expression = { path => - val javaValue = i5.fromCatalyst(path) - val value = NewInstance(i6.runtimeClass, Seq(javaValue), innerJvmRepr) + val fromCatalyst: Expression => Expression = { path => + val javaValue = i5.fromCatalyst(path) + val value = NewInstance(i6.runtimeClass, Seq(javaValue), innerJvmRepr) - WrapOption(value, innerJvmRepr) - } + WrapOption(value, innerJvmRepr) + } - val jvmr = ObjectType(classOf[Option[F]]) + val jvmr = ObjectType(classOf[Option[F]]) - new RecordFieldEncoder[Option[F]]( - encoder = new TypedEncoder[Option[F]] { - val nullable = true + new RecordFieldEncoder[Option[F]]( + encoder = new TypedEncoder[Option[F]] { + val nullable = true - val jvmRepr = jvmr + val jvmRepr = jvmr - @inline def catalystRepr: DataType = i5.catalystRepr + @inline def catalystRepr: DataType = i5.catalystRepr - def fromCatalyst(path: Expression): Expression = { - val javaValue = i5.fromCatalyst(path) - val value = NewInstance( - i6.runtimeClass, Seq(javaValue), innerJvmRepr) + def fromCatalyst(path: Expression): Expression = { + val javaValue = i5.fromCatalyst(path) + val value = NewInstance(i6.runtimeClass, Seq(javaValue), innerJvmRepr) - WrapOption(value, innerJvmRepr) - } + WrapOption(value, innerJvmRepr) + } - def toCatalyst(path: Expression): Expression = catalyst(path) + def toCatalyst(path: Expression): Expression = catalyst(path) - override def toString: String = s"RecordFieldEncoder.optionValueClass[${i6.runtimeClass.getName}]('${fieldName}', $i5)" - }, - jvmRepr = jvmr, - fromCatalyst = fromCatalyst, - toCatalyst = catalyst - ) + override def toString: String = + s"RecordFieldEncoder.optionValueClass[${i6.runtimeClass.getName}]('${fieldName}', $i5)" + }, + jvmRepr = jvmr, + fromCatalyst = fromCatalyst, + toCatalyst = catalyst + ) } /** @@ -259,8 +284,14 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority { * @tparam H the single field of the value class (with guarantee it's not a `Unit` value) * 
@tparam V the inner value type */ - implicit def valueClass[F : IsValueClass, G <: ::[_, HNil], H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], K <: Symbol, V, KS <: ::[_ <: Symbol, HNil]] - (implicit + implicit def valueClass[ + F: IsValueClass, + G <: ::[_, HNil], + H <: ::[_ <: FieldType[_ <: Symbol, _], HNil], + K <: Symbol, + V, + KS <: ::[_ <: Symbol, HNil] + ](implicit i0: LabelledGeneric.Aux[F, G], i1: DropUnitValues.Aux[G, H], i2: IsHCons.Aux[H, _ <: FieldType[K, V], HNil], @@ -269,40 +300,47 @@ object RecordFieldEncoder extends RecordFieldEncoderLowPriority { i5: TypedEncoder[V], i6: ClassTag[F] ): RecordFieldEncoder[F] = { - val cls = i6.runtimeClass - val jvmr = i5.jvmRepr - val fieldName = i4.head(i3()).name - - new RecordFieldEncoder[F]( - encoder = new TypedEncoder[F] { - def nullable = i5.nullable - - def jvmRepr = jvmr - - def catalystRepr: DataType = i5.catalystRepr - - def fromCatalyst(path: Expression): Expression = - i5.fromCatalyst(path) - - @inline def toCatalyst(path: Expression): Expression = - i5.toCatalyst(path) - - override def toString: String = s"RecordFieldEncoder.valueClass[${cls.getName}]('${fieldName}', ${i5})" - }, - jvmRepr = FramelessInternals.objectTypeFor[F], - fromCatalyst = { expr: Expression => - NewInstance( - i6.runtimeClass, - i5.fromCatalyst(expr) :: Nil, - ObjectType(i6.runtimeClass)) - }, - toCatalyst = { expr: Expression => - i5.toCatalyst(Invoke(expr, fieldName, jvmr)) - } - ) + val cls = i6.runtimeClass + val jvmr = i5.jvmRepr + val fieldName = i4.head(i3()).name + + new RecordFieldEncoder[F]( + encoder = new TypedEncoder[F] { + def nullable = i5.nullable + + def jvmRepr = jvmr + + def catalystRepr: DataType = i5.catalystRepr + + def fromCatalyst(path: Expression): Expression = + i5.fromCatalyst(path) + + @inline def toCatalyst(path: Expression): Expression = + i5.toCatalyst(path) + + override def toString: String = + s"RecordFieldEncoder.valueClass[${cls.getName}]('${fieldName}', ${i5})" + }, + jvmRepr = FramelessInternals.objectTypeFor[F], + fromCatalyst = { expr: Expression => + NewInstance( + i6.runtimeClass, + i5.fromCatalyst(expr) :: Nil, + ObjectType(i6.runtimeClass) + ) + }, + toCatalyst = { expr: Expression => + i5.toCatalyst(Invoke(expr, fieldName, jvmr)) + } + ) } } private[frameless] sealed trait RecordFieldEncoderLowPriority { - implicit def apply[T](implicit e: TypedEncoder[T]): RecordFieldEncoder[T] = new RecordFieldEncoder[T](e, e.jvmRepr, e.fromCatalyst, e.toCatalyst) + + implicit def apply[T]( + implicit + e: TypedEncoder[T] + ): RecordFieldEncoder[T] = + new RecordFieldEncoder[T](e, e.jvmRepr, e.fromCatalyst, e.toCatalyst) } diff --git a/dataset/src/main/scala/frameless/SparkDelay.scala b/dataset/src/main/scala/frameless/SparkDelay.scala index 74a651ae3..83e78d3c3 100644 --- a/dataset/src/main/scala/frameless/SparkDelay.scala +++ b/dataset/src/main/scala/frameless/SparkDelay.scala @@ -3,5 +3,10 @@ package frameless import org.apache.spark.sql.SparkSession trait SparkDelay[F[_]] { - def delay[A](a: => A)(implicit spark: SparkSession): F[A] + + def delay[A]( + a: => A + )(implicit + spark: SparkSession + ): F[A] } diff --git a/dataset/src/main/scala/frameless/TypedColumn.scala b/dataset/src/main/scala/frameless/TypedColumn.scala index 0bbaf6fed..ef3fe52b0 100644 --- a/dataset/src/main/scala/frameless/TypedColumn.scala +++ b/dataset/src/main/scala/frameless/TypedColumn.scala @@ -1,11 +1,11 @@ package frameless -import frameless.functions.{litAggr, lit => flit} +import frameless.functions.{ litAggr, lit => flit } import 
frameless.syntax._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.DecimalType -import org.apache.spark.sql.{Column, FramelessInternals} +import org.apache.spark.sql.{ Column, FramelessInternals } import shapeless._ import shapeless.ops.record.Selector @@ -21,91 +21,121 @@ sealed trait UntypedExpression[T] { override def toString: String = expr.toString() } -/** Expression used in `select`-like constructions. - */ -sealed class TypedColumn[T, U](expr: Expression)( - implicit val uenc: TypedEncoder[U] -) extends AbstractTypedColumn[T, U](expr) { +/** + * Expression used in `select`-like constructions. + */ +sealed class TypedColumn[T, U]( + expr: Expression + )(implicit + val uenc: TypedEncoder[U]) + extends AbstractTypedColumn[T, U](expr) { type ThisType[A, B] = TypedColumn[A, B] - def this(column: Column)(implicit uencoder: TypedEncoder[U]) = + def this( + column: Column + )(implicit + uencoder: TypedEncoder[U] + ) = this(FramelessInternals.expr(column)) - override def typed[W, U1: TypedEncoder](c: Column): TypedColumn[W, U1] = c.typedColumn + override def typed[W, U1: TypedEncoder](c: Column): TypedColumn[W, U1] = + c.typedColumn override def lit[U1: TypedEncoder](c: U1): TypedColumn[T, U1] = flit(c) } -/** Expression used in `agg`-like constructions. - */ -sealed class TypedAggregate[T, U](expr: Expression)( - implicit val uenc: TypedEncoder[U] -) extends AbstractTypedColumn[T, U](expr) { +/** + * Expression used in `agg`-like constructions. + */ +sealed class TypedAggregate[T, U]( + expr: Expression + )(implicit + val uenc: TypedEncoder[U]) + extends AbstractTypedColumn[T, U](expr) { type ThisType[A, B] = TypedAggregate[A, B] - def this(column: Column)(implicit uencoder: TypedEncoder[U]) = { + def this( + column: Column + )(implicit + uencoder: TypedEncoder[U] + ) = { this(FramelessInternals.expr(column)) } - override def typed[W, U1: TypedEncoder](c: Column): TypedAggregate[W, U1] = c.typedAggregate + override def typed[W, U1: TypedEncoder](c: Column): TypedAggregate[W, U1] = + c.typedAggregate override def lit[U1: TypedEncoder](c: U1): TypedAggregate[T, U1] = litAggr(c) } -/** Generic representation of a typed column. A typed column can either be a [[TypedAggregate]] or - * a [[frameless.TypedColumn]]. - * - * Documentation marked "apache/spark" is thanks to apache/spark Contributors - * at https://github.com/apache/spark, licensed under Apache v2.0 available at - * http://www.apache.org/licenses/LICENSE-2.0 - * - * @tparam T phantom type representing the dataset on which this columns is - * selected. When `T = A with B` the selection is on either A or B. - * @tparam U type of column - */ -abstract class AbstractTypedColumn[T, U] - (val expr: Expression) - (implicit val uencoder: TypedEncoder[U]) +/** + * Generic representation of a typed column. A typed column can either be a [[TypedAggregate]] or + * a [[frameless.TypedColumn]]. + * + * Documentation marked "apache/spark" is thanks to apache/spark Contributors + * at https://github.com/apache/spark, licensed under Apache v2.0 available at + * http://www.apache.org/licenses/LICENSE-2.0 + * + * @tparam T phantom type representing the dataset on which this columns is + * selected. When `T = A with B` the selection is on either A or B. 
+ * @tparam U type of column + */ +abstract class AbstractTypedColumn[T, U]( + val expr: Expression + )(implicit + val uencoder: TypedEncoder[U]) extends UntypedExpression[T] { self => type ThisType[A, B] <: AbstractTypedColumn[A, B] - /** A helper class to make to simplify working with Optional fields. - * - * {{{ - * val x: TypedColumn[Option[Int]] = _ - * x.opt.map(_*2) // This only compiles if the type of x is Option[X] (in this example X is of type Int) - * }}} - * - * @note Known issue: map() will NOT work when the applied function is a udf(). - * It will compile and then throw a runtime error. - **/ + /** + * A helper class to simplify working with Optional fields. + * + * {{{ + * val x: TypedColumn[Option[Int]] = _ + * x.opt.map(_*2) // This only compiles if the type of x is Option[X] (in this example X is of type Int) + * }}} + * + * @note Known issue: map() will NOT work when the applied function is a udf(). + * It will compile and then throw a runtime error. + */ trait Mapper[X] { - def map[G, OutputType[_,_]](u: ThisType[T, X] => OutputType[T,G]) - (implicit - ev: OutputType[T,G] <:< AbstractTypedColumn[T, G] + + def map[G, OutputType[_, _]]( + u: ThisType[T, X] => OutputType[T, G] + )(implicit + ev: OutputType[T, G] <:< AbstractTypedColumn[T, G] ): OutputType[T, Option[G]] = { - u(self.asInstanceOf[ThisType[T, X]]).asInstanceOf[OutputType[T, Option[G]]] + u(self.asInstanceOf[ThisType[T, X]]) + .asInstanceOf[OutputType[T, Option[G]]] } } - /** Makes it easier to work with Optional columns. It returns an instance of `Mapper[X]` - * where `X` is type of the unwrapped Optional. E.g., in the case of `Option[Long]`, - * `X` is of type Long. - * - * {{{ - * val x: TypedColumn[Option[Int]] = _ - * x.opt.map(_*2) - * }}} - * */ - def opt[X](implicit x: U <:< Option[X]): Mapper[X] = new Mapper[X] {} + /** + * Makes it easier to work with Optional columns. It returns an instance of `Mapper[X]` + * where `X` is the type of the unwrapped Optional. E.g., in the case of `Option[Long]`, + * `X` is of type Long. + * + * {{{ + * val x: TypedColumn[Option[Int]] = _ + * x.opt.map(_*2) + * }}} + */ + def opt[X]( + implicit + x: U <:< Option[X] + ): Mapper[X] = new Mapper[X] {} /** Fall back to an untyped Column */ def untyped: Column = new Column(expr) - private def equalsTo[TT, W](other: ThisType[TT, U])(implicit w: With.Aux[T, TT, W]): ThisType[W, Boolean] = typed { + private def equalsTo[TT, W]( + other: ThisType[TT, U] + )(implicit + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = typed { if (uencoder.nullable) EqualNullSafe(self.expr, other.expr) else EqualTo(self.expr, other.expr) } @@ -120,773 +150,1125 @@ abstract class AbstractTypedColumn[T, U] /** Creates a typed column of either TypedColumn or TypedAggregate. */ def lit[U1: TypedEncoder](c: U1): ThisType[T, U1] - /** Equality test. - * {{{ - * df.filter( df.col('a) === 1 ) - * }}} - * - * apache/spark - */ + /** + * Equality test. + * {{{ + * df.filter( df.col('a) === 1 ) + * }}} + * + * apache/spark + */ def ===(u: U): ThisType[T, Boolean] = equalsTo(lit(u)) - /** Equality test. - * {{{ - * df.filter( df.col('a) === df.col('b) ) - * }}} - * - * apache/spark - */ - def ===[TT, W](other: ThisType[TT, U])(implicit w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * Equality test. + * {{{ + * df.filter( df.col('a) === df.col('b) ) + * }}} + * + * apache/spark + */ + def ===[TT, W]( + other: ThisType[TT, U] + )(implicit + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = equalsTo(other) - /** Inequality test.
- * - * {{{ - * df.filter(df.col('a) =!= df.col('b)) - * }}} - * - * apache/spark - */ - def =!=[TT, W](other: ThisType[TT, U])(implicit w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * Inequality test. + * + * {{{ + * df.filter(df.col('a) =!= df.col('b)) + * }}} + * + * apache/spark + */ + def =!=[TT, W]( + other: ThisType[TT, U] + )(implicit + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = typed(Not(equalsTo(other).expr)) - /** Inequality test. - * - * {{{ - * df.filter(df.col('a) =!= "a") - * }}} - * - * apache/spark - */ + /** + * Inequality test. + * + * {{{ + * df.filter(df.col('a) =!= "a") + * }}} + * + * apache/spark + */ def =!=(u: U): ThisType[T, Boolean] = typed(Not(equalsTo(lit(u)).expr)) - /** True if the current expression is an Option and it's None. - * - * apache/spark - */ - def isNone(implicit i0: U <:< Option[_]): ThisType[T, Boolean] = + /** + * True if the current expression is an Option and it's None. + * + * apache/spark + */ + def isNone( + implicit + i0: U <:< Option[_] + ): ThisType[T, Boolean] = typed(IsNull(expr)) - /** True if the current expression is an Option and it's not None. - * - * apache/spark - */ - def isNotNone(implicit i0: U <:< Option[_]): ThisType[T, Boolean] = + /** + * True if the current expression is an Option and it's not None. + * + * apache/spark + */ + def isNotNone( + implicit + i0: U <:< Option[_] + ): ThisType[T, Boolean] = typed(IsNotNull(expr)) - /** True if the current expression is a fractional number and is not NaN. - * - * apache/spark - */ - def isNaN(implicit n: CatalystNaN[U]): ThisType[T, Boolean] = + /** + * True if the current expression is a fractional number and is NaN. + * + * apache/spark + */ + def isNaN( + implicit + n: CatalystNaN[U] + ): ThisType[T, Boolean] = typed(self.untyped.isNaN) /** - * True if the value for this optional column `exists` as expected - * (see `Option.exists`). - * - * {{{ - * df.col('opt).isSome(_ === someOtherCol) - * }}} - */ - def isSome[V](exists: ThisType[T, V] => ThisType[T, Boolean])(implicit i0: U <:< Option[V]): ThisType[T, Boolean] = someOr[V](exists, false) + * True if the value for this optional column `exists` as expected + * (see `Option.exists`). + * + * {{{ + * df.col('opt).isSome(_ === someOtherCol) + * }}} + */ + def isSome[V]( + exists: ThisType[T, V] => ThisType[T, Boolean] + )(implicit + i0: U <:< Option[V] + ): ThisType[T, Boolean] = someOr[V](exists, false) /** - * True if the value for this optional column `exists` as expected, - * or is `None`. (see `Option.forall`).
+ * + * {{{ + * df.col('opt).isSomeOrNone(_ === someOtherCol) + * }}} + */ + def isSomeOrNone[V]( + exists: ThisType[T, V] => ThisType[T, Boolean] + )(implicit + i0: U <:< Option[V] + ): ThisType[T, Boolean] = someOr[V](exists, true) + + private def someOr[V]( + exists: ThisType[T, V] => ThisType[T, Boolean], + default: Boolean + )(implicit + i0: U <:< Option[V] + ): ThisType[T, Boolean] = { val defaultExpr = if (default) Literal.TrueLiteral else Literal.FalseLiteral typed(Coalesce(Seq(opt(i0).map(exists).expr, defaultExpr))) } - /** Convert an Optional column by providing a default value. - * - * {{{ - * df(df('opt).getOrElse(df('defaultValue))) - * }}} - */ - def getOrElse[TT, W, Out](default: ThisType[TT, Out])(implicit i0: U =:= Option[Out], i1: With.Aux[T, TT, W]): ThisType[W, Out] = + /** + * Convert an Optional column by providing a default value. + * + * {{{ + * df(df('opt).getOrElse(df('defaultValue))) + * }}} + */ + def getOrElse[TT, W, Out]( + default: ThisType[TT, Out] + )(implicit + i0: U =:= Option[Out], + i1: With.Aux[T, TT, W] + ): ThisType[W, Out] = typed(Coalesce(Seq(expr, default.expr)))(default.uencoder) - /** Convert an Optional column by providing a default value. - * - * {{{ - * df( df('opt).getOrElse(defaultConstant) ) - * }}} - */ - def getOrElse[Out: TypedEncoder](default: Out)(implicit i0: U =:= Option[Out]): ThisType[T, Out] = + /** + * Convert an Optional column by providing a default value. + * + * {{{ + * df( df('opt).getOrElse(defaultConstant) ) + * }}} + */ + def getOrElse[Out: TypedEncoder]( + default: Out + )(implicit + i0: U =:= Option[Out] + ): ThisType[T, Out] = getOrElse(lit[Out](default)) - /** Sum of this expression and another expression. - * - * {{{ - * // The following selects the sum of a person's height and weight. - * people.select( people.col('height) plus people.col('weight) ) - * }}} - * - * apache/spark - */ - def plus[TT, W](other: ThisType[TT, U])(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W]): ThisType[W, U] = + /** + * Sum of this expression and another expression. + * + * {{{ + * // The following selects the sum of a person's height and weight. + * people.select( people.col('height) plus people.col('weight) ) + * }}} + * + * apache/spark + */ + def plus[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystNumeric[U], + w: With.Aux[T, TT, W] + ): ThisType[W, U] = typed(self.untyped.plus(other.untyped)) - /** Sum of this expression and another expression. - * {{{ - * // The following selects the sum of a person's height and weight. - * people.select( people.col('height) + people.col('weight) ) - * }}} - * - * apache/spark - */ - def +[TT, W](other: ThisType[TT, U])(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W]): ThisType[W, U] = + /** + * Sum of this expression and another expression. + * {{{ + * // The following selects the sum of a person's height and weight. + * people.select( people.col('height) + people.col('weight) ) + * }}} + * + * apache/spark + */ + def +[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystNumeric[U], + w: With.Aux[T, TT, W] + ): ThisType[W, U] = plus(other) - /** Sum of this expression (column) with a constant. - * {{{ - * // The following selects the sum of a person's height and weight. - * people.select( people('height) + 2 ) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def +(u: U)(implicit n: CatalystNumeric[U]): ThisType[T, U] = + /** + * Sum of this expression (column) with a constant. 
+ * {{{ + * // The following selects the sum of a person's height and weight. + * people.select( people('height) + 2 ) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def +( + u: U + )(implicit + n: CatalystNumeric[U] + ): ThisType[T, U] = typed(self.untyped.plus(u)) /** - * Inversion of boolean expression, i.e. NOT. - * {{{ - * // Select rows that are not active (isActive === false) - * df.filter( !df('isActive) ) - * }}} - * - * apache/spark - */ - def unary_!(implicit i0: U <:< Boolean): ThisType[T, Boolean] = + * Inversion of boolean expression, i.e. NOT. + * {{{ + * // Select rows that are not active (isActive === false) + * df.filter( !df('isActive) ) + * }}} + * + * apache/spark + */ + def unary_!( + implicit + i0: U <:< Boolean + ): ThisType[T, Boolean] = typed(!untyped) - /** Unary minus, i.e. negate the expression. - * {{{ - * // Select the amount column and negates all values. - * df.select( -df('amount) ) - * }}} - * - * apache/spark - */ - def unary_-(implicit n: CatalystNumeric[U]): ThisType[T, U] = + /** + * Unary minus, i.e. negate the expression. + * {{{ + * // Select the amount column and negates all values. + * df.select( -df('amount) ) + * }}} + * + * apache/spark + */ + def unary_-( + implicit + n: CatalystNumeric[U] + ): ThisType[T, U] = typed(-self.untyped) - /** Subtraction. Subtract the other expression from this expression. - * {{{ - * // The following selects the difference between people's height and their weight. - * people.select( people.col('height) minus people.col('weight) ) - * }}} - * - * apache/spark - */ - def minus[TT, W](other: ThisType[TT, U])(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W]): ThisType[W, U] = + /** + * Subtraction. Subtract the other expression from this expression. + * {{{ + * // The following selects the difference between people's height and their weight. + * people.select( people.col('height) minus people.col('weight) ) + * }}} + * + * apache/spark + */ + def minus[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystNumeric[U], + w: With.Aux[T, TT, W] + ): ThisType[W, U] = typed(self.untyped.minus(other.untyped)) - /** Subtraction. Subtract the other expression from this expression. - * {{{ - * // The following selects the difference between people's height and their weight. - * people.select( people.col('height) - people.col('weight) ) - * }}} - * - * apache/spark - */ - def -[TT, W](other: ThisType[TT, U])(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W]): ThisType[W, U] = + /** + * Subtraction. Subtract the other expression from this expression. + * {{{ + * // The following selects the difference between people's height and their weight. + * people.select( people.col('height) - people.col('weight) ) + * }}} + * + * apache/spark + */ + def -[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystNumeric[U], + w: With.Aux[T, TT, W] + ): ThisType[W, U] = minus(other) - /** Subtraction. Subtract the other expression from this expression. - * {{{ - * // The following selects the difference between people's height and their weight. - * people.select( people('height) - 1 ) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def -(u: U)(implicit n: CatalystNumeric[U]): ThisType[T, U] = + /** + * Subtraction. Subtract the other expression from this expression. + * {{{ + * // The following selects the difference between people's height and their weight. 
+ * people.select( people('height) - 1 ) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def -( + u: U + )(implicit + n: CatalystNumeric[U] + ): ThisType[T, U] = typed(self.untyped.minus(u)) - /** Multiplication of this expression and another expression. - * {{{ - * // The following multiplies a person's height by their weight. - * people.select( people.col('height) multiply people.col('weight) ) - * }}} - * - * apache/spark - */ - def multiply[TT, W] - (other: ThisType[TT, U]) - (implicit + /** + * Multiplication of this expression and another expression. + * {{{ + * // The following multiplies a person's height by their weight. + * people.select( people.col('height) multiply people.col('weight) ) + * }}} + * + * apache/spark + */ + def multiply[TT, W]( + other: ThisType[TT, U] + )(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W], t: ClassTag[U] ): ThisType[W, U] = typed { - if (t.runtimeClass == BigDecimal(0).getClass) { - // That's apparently the only way to get sound multiplication. - // See https://issues.apache.org/jira/browse/SPARK-22036 - val dt = DecimalType(20, 14) - self.untyped.cast(dt).multiply(other.untyped.cast(dt)) - } else { - self.untyped.multiply(other.untyped) - } + if (t.runtimeClass == BigDecimal(0).getClass) { + // That's apparently the only way to get sound multiplication. + // See https://issues.apache.org/jira/browse/SPARK-22036 + val dt = DecimalType(20, 14) + self.untyped.cast(dt).multiply(other.untyped.cast(dt)) + } else { + self.untyped.multiply(other.untyped) } + } - /** Multiplication of this expression and another expression. - * {{{ - * // The following multiplies a person's height by their weight. - * people.select( people.col('height) * people.col('weight) ) - * }}} - * - * apache/spark - */ - def *[TT, W](other: ThisType[TT, U])(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W], t: ClassTag[U]): ThisType[W, U] = + /** + * Multiplication of this expression and another expression. + * {{{ + * // The following multiplies a person's height by their weight. + * people.select( people.col('height) * people.col('weight) ) + * }}} + * + * apache/spark + */ + def *[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystNumeric[U], + w: With.Aux[T, TT, W], + t: ClassTag[U] + ): ThisType[W, U] = multiply(other) - /** Multiplication of this expression a constant. - * {{{ - * // The following multiplies a person's height by their weight. - * people.select( people.col('height) * people.col('weight) ) - * }}} - * - * apache/spark - */ - def *(u: U)(implicit n: CatalystNumeric[U]): ThisType[T, U] = + /** + * Multiplication of this expression a constant. + * {{{ + * // The following multiplies a person's height by their weight. + * people.select( people.col('height) * people.col('weight) ) + * }}} + * + * apache/spark + */ + def *( + u: U + )(implicit + n: CatalystNumeric[U] + ): ThisType[T, U] = typed(self.untyped.multiply(u)) - /** Modulo (a.k.a. remainder) expression. - * - * apache/spark - */ - def mod[Out: TypedEncoder, TT, W](other: ThisType[TT, U])(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W]): ThisType[W, Out] = + /** + * Modulo (a.k.a. remainder) expression. + * + * apache/spark + */ + def mod[Out: TypedEncoder, TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystNumeric[U], + w: With.Aux[T, TT, W] + ): ThisType[W, Out] = typed(self.untyped.mod(other.untyped)) - /** Modulo (a.k.a. remainder) expression. 
- * - * apache/spark - */ - def %[TT, W](other: ThisType[TT, U])(implicit n: CatalystNumeric[U], w: With.Aux[T, TT, W]): ThisType[W, U] = + /** + * Modulo (a.k.a. remainder) expression. + * + * apache/spark + */ + def %[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystNumeric[U], + w: With.Aux[T, TT, W] + ): ThisType[W, U] = mod(other) - /** Modulo (a.k.a. remainder) expression. - * - * apache/spark - */ - def %(u: U)(implicit n: CatalystNumeric[U]): ThisType[T, U] = + /** + * Modulo (a.k.a. remainder) expression. + * + * apache/spark + */ + def %( + u: U + )(implicit + n: CatalystNumeric[U] + ): ThisType[T, U] = typed(self.untyped.mod(u)) - /** Division this expression by another expression. - * {{{ - * // The following divides a person's height by their weight. - * people.select( people('height) / people('weight) ) - * }}} - * - * @param other another column of the same type - * apache/spark - */ - def divide[Out: TypedEncoder, TT, W](other: ThisType[TT, U])(implicit n: CatalystDivisible[U, Out], w: With.Aux[T, TT, W]): ThisType[W, Out] = + /** + * Division this expression by another expression. + * {{{ + * // The following divides a person's height by their weight. + * people.select( people('height) / people('weight) ) + * }}} + * + * @param other another column of the same type + * apache/spark + */ + def divide[Out: TypedEncoder, TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystDivisible[U, Out], + w: With.Aux[T, TT, W] + ): ThisType[W, Out] = typed(self.untyped.divide(other.untyped)) - /** Division this expression by another expression. - * {{{ - * // The following divides a person's height by their weight. - * people.select( people('height) / people('weight) ) - * }}} - * - * @param other another column of the same type - * apache/spark - */ - def /[Out, TT, W](other: ThisType[TT, U])(implicit n: CatalystDivisible[U, Out], e: TypedEncoder[Out], w: With.Aux[T, TT, W]): ThisType[W, Out] = + /** + * Division this expression by another expression. + * {{{ + * // The following divides a person's height by their weight. + * people.select( people('height) / people('weight) ) + * }}} + * + * @param other another column of the same type + * apache/spark + */ + def /[Out, TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystDivisible[U, Out], + e: TypedEncoder[Out], + w: With.Aux[T, TT, W] + ): ThisType[W, Out] = divide(other) - /** Division this expression by another expression. - * {{{ - * // The following divides a person's height by their weight. - * people.select( people('height) / 2 ) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def /(u: U)(implicit n: CatalystNumeric[U]): ThisType[T, Double] = + /** + * Division this expression by another expression. + * {{{ + * // The following divides a person's height by their weight. 
+ * people.select( people('height) / 2 ) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def /( + u: U + )(implicit + n: CatalystNumeric[U] + ): ThisType[T, Double] = typed(self.untyped.divide(u)) - /** Returns a descending ordering used in sorting - * - * apache/spark - */ - def desc(implicit catalystOrdered: CatalystOrdered[U]): SortedTypedColumn[T, U] = + /** + * Returns a descending ordering used in sorting + * + * apache/spark + */ + def desc( + implicit + catalystOrdered: CatalystOrdered[U] + ): SortedTypedColumn[T, U] = new SortedTypedColumn[T, U](untyped.desc) - /** Returns an ascending ordering used in sorting - * - * apache/spark - */ - def asc(implicit catalystOrdered: CatalystOrdered[U]): SortedTypedColumn[T, U] = + /** + * Returns an ascending ordering used in sorting + * + * apache/spark + */ + def asc( + implicit + catalystOrdered: CatalystOrdered[U] + ): SortedTypedColumn[T, U] = new SortedTypedColumn[T, U](untyped.asc) - /** Bitwise AND this expression and another expression. - * {{{ - * df.select(df.col('colA) bitwiseAND (df.col('colB))) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def bitwiseAND(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] = + /** + * Bitwise AND this expression and another expression. + * {{{ + * df.select(df.col('colA) bitwiseAND (df.col('colB))) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def bitwiseAND( + u: U + )(implicit + n: CatalystBitwise[U] + ): ThisType[T, U] = typed(self.untyped.bitwiseAND(u)) - /** Bitwise AND this expression and another expression. - * {{{ - * df.select(df.col('colA) bitwiseAND (df.col('colB))) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def bitwiseAND[TT, W](other: ThisType[TT, U])(implicit n: CatalystBitwise[U], w: With.Aux[T, TT, W]): ThisType[W, U] = + /** + * Bitwise AND this expression and another expression. + * {{{ + * df.select(df.col('colA) bitwiseAND (df.col('colB))) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def bitwiseAND[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystBitwise[U], + w: With.Aux[T, TT, W] + ): ThisType[W, U] = typed(self.untyped.bitwiseAND(other.untyped)) - /** Bitwise AND this expression and another expression (of same type). - * {{{ - * df.select(df.col('colA).cast[Int] & -1) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def &(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] = + /** + * Bitwise AND this expression and another expression (of same type). + * {{{ + * df.select(df.col('colA).cast[Int] & -1) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def &( + u: U + )(implicit + n: CatalystBitwise[U] + ): ThisType[T, U] = bitwiseAND(u) - /** Bitwise AND this expression and another expression. - * {{{ - * df.select(df.col('colA) & (df.col('colB))) - * }}} - * - * @param other a constant of the same type - * apache/spark - */ - def &[TT, W](other: ThisType[TT, U])(implicit n: CatalystBitwise[U], w: With.Aux[T, TT, W]): ThisType[W, U] = + /** + * Bitwise AND this expression and another expression. + * {{{ + * df.select(df.col('colA) & (df.col('colB))) + * }}} + * + * @param other a constant of the same type + * apache/spark + */ + def &[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystBitwise[U], + w: With.Aux[T, TT, W] + ): ThisType[W, U] = bitwiseAND(other) - /** Bitwise OR this expression and another expression. 
- * {{{ - * df.select(df.col('colA) bitwiseOR (df.col('colB))) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def bitwiseOR(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] = + /** + * Bitwise OR this expression and another expression. + * {{{ + * df.select(df.col('colA) bitwiseOR (df.col('colB))) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def bitwiseOR( + u: U + )(implicit + n: CatalystBitwise[U] + ): ThisType[T, U] = typed(self.untyped.bitwiseOR(u)) - /** Bitwise OR this expression and another expression. - * {{{ - * df.select(df.col('colA) bitwiseOR (df.col('colB))) - * }}} - * - * @param other a constant of the same type - * apache/spark - */ - def bitwiseOR[TT, W](other: ThisType[TT, U])(implicit n: CatalystBitwise[U], w: With.Aux[T, TT, W]): ThisType[W, U] = + /** + * Bitwise OR this expression and another expression. + * {{{ + * df.select(df.col('colA) bitwiseOR (df.col('colB))) + * }}} + * + * @param other a constant of the same type + * apache/spark + */ + def bitwiseOR[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystBitwise[U], + w: With.Aux[T, TT, W] + ): ThisType[W, U] = typed(self.untyped.bitwiseOR(other.untyped)) - /** Bitwise OR this expression and another expression (of same type). - * {{{ - * df.select(df.col('colA).cast[Long] | 1L) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def |(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] = + /** + * Bitwise OR this expression and another expression (of same type). + * {{{ + * df.select(df.col('colA).cast[Long] | 1L) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def |( + u: U + )(implicit + n: CatalystBitwise[U] + ): ThisType[T, U] = bitwiseOR(u) - /** Bitwise OR this expression and another expression. - * {{{ - * df.select(df.col('colA) | (df.col('colB))) - * }}} - * - * @param other a constant of the same type - * apache/spark - */ - def |[TT, W](other: ThisType[TT, U])(implicit n: CatalystBitwise[U], w: With.Aux[T, TT, W]): ThisType[W, U] = + /** + * Bitwise OR this expression and another expression. + * {{{ + * df.select(df.col('colA) | (df.col('colB))) + * }}} + * + * @param other a constant of the same type + * apache/spark + */ + def |[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystBitwise[U], + w: With.Aux[T, TT, W] + ): ThisType[W, U] = bitwiseOR(other) - /** Bitwise XOR this expression and another expression. - * {{{ - * df.select(df.col('colA) bitwiseXOR (df.col('colB))) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def bitwiseXOR(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] = + /** + * Bitwise XOR this expression and another expression. + * {{{ + * df.select(df.col('colA) bitwiseXOR (df.col('colB))) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def bitwiseXOR( + u: U + )(implicit + n: CatalystBitwise[U] + ): ThisType[T, U] = typed(self.untyped.bitwiseXOR(u)) - /** Bitwise XOR this expression and another expression. - * {{{ - * df.select(df.col('colA) bitwiseXOR (df.col('colB))) - * }}} - * - * @param other a constant of the same type - * apache/spark - */ - def bitwiseXOR[TT, W](other: ThisType[TT, U])(implicit n: CatalystBitwise[U], w: With.Aux[T, TT, W]): ThisType[W, U] = + /** + * Bitwise XOR this expression and another expression. 
+ * {{{ + * df.select(df.col('colA) bitwiseXOR (df.col('colB))) + * }}} + * + * @param other a constant of the same type + * apache/spark + */ + def bitwiseXOR[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystBitwise[U], + w: With.Aux[T, TT, W] + ): ThisType[W, U] = typed(self.untyped.bitwiseXOR(other.untyped)) - /** Bitwise XOR this expression and another expression (of same type). - * {{{ - * df.select(df.col('colA).cast[Long] ^ 1L) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def ^(u: U)(implicit n: CatalystBitwise[U]): ThisType[T, U] = + /** + * Bitwise XOR this expression and another expression (of same type). + * {{{ + * df.select(df.col('colA).cast[Long] ^ 1L) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def ^( + u: U + )(implicit + n: CatalystBitwise[U] + ): ThisType[T, U] = bitwiseXOR(u) - /** Bitwise XOR this expression and another expression. - * {{{ - * df.select(df.col('colA) ^ (df.col('colB))) - * }}} - * - * @param other a constant of the same type - * apache/spark - */ - def ^[TT, W](other: ThisType[TT, U])(implicit n: CatalystBitwise[U], w: With.Aux[T, TT, W]): ThisType[W, U] = + /** + * Bitwise XOR this expression and another expression. + * {{{ + * df.select(df.col('colA) ^ (df.col('colB))) + * }}} + * + * @param other a constant of the same type + * apache/spark + */ + def ^[TT, W]( + other: ThisType[TT, U] + )(implicit + n: CatalystBitwise[U], + w: With.Aux[T, TT, W] + ): ThisType[W, U] = bitwiseXOR(other) - /** Casts the column to a different type. - * {{{ - * df.select(df('a).cast[Int]) - * }}} - */ - def cast[A: TypedEncoder](implicit c: CatalystCast[U, A]): ThisType[T, A] = + /** + * Casts the column to a different type. + * {{{ + * df.select(df('a).cast[Int]) + * }}} + */ + def cast[A: TypedEncoder]( + implicit + c: CatalystCast[U, A] + ): ThisType[T, A] = typed(self.untyped.cast(TypedEncoder[A].catalystRepr)) /** - * An expression that returns a substring - * {{{ - * df.select(df('a).substr(0, 5)) - * }}} - * - * @param startPos starting position - * @param len length of the substring - */ - def substr(startPos: Int, len: Int)(implicit ev: U =:= String): ThisType[T, String] = + * An expression that returns a substring + * {{{ + * df.select(df('a).substr(0, 5)) + * }}} + * + * @param startPos starting position + * @param len length of the substring + */ + def substr( + startPos: Int, + len: Int + )(implicit + ev: U =:= String + ): ThisType[T, String] = typed(self.untyped.substr(startPos, len)) /** - * An expression that returns a substring - * {{{ - * df.select(df('a).substr(df('b), df('c))) - * }}} - * - * @param startPos expression for the starting position - * @param len expression for the length of the substring - */ - def substr[TT1, TT2, W1, W2](startPos: ThisType[TT1, Int], len: ThisType[TT2, Int]) - (implicit - ev: U =:= String, - w1: With.Aux[T, TT1, W1], - w2: With.Aux[W1, TT2, W2]): ThisType[W2, String] = + * An expression that returns a substring + * {{{ + * df.select(df('a).substr(df('b), df('c))) + * }}} + * + * @param startPos expression for the starting position + * @param len expression for the length of the substring + */ + def substr[TT1, TT2, W1, W2]( + startPos: ThisType[TT1, Int], + len: ThisType[TT2, Int] + )(implicit + ev: U =:= String, + w1: With.Aux[T, TT1, W1], + w2: With.Aux[W1, TT2, W2] + ): ThisType[W2, String] = typed(self.untyped.substr(startPos.untyped, len.untyped)) - /** SQL like expression. Returns a boolean column based on a SQL LIKE match. 
- * {{{ - * val ds = TypedDataset.create(X2("foo", "bar") :: Nil) - * // true - * ds.select(ds('a).like("foo")) - * - * // Selected column has value "bar" - * ds.select(when(ds('a).like("f"), ds('a)).otherwise(ds('b)) - * }}} - * apache/spark - */ - def like(literal: String)(implicit ev: U =:= String): ThisType[T, Boolean] = + /** + * SQL like expression. Returns a boolean column based on a SQL LIKE match. + * {{{ + * val ds = TypedDataset.create(X2("foo", "bar") :: Nil) + * // true + * ds.select(ds('a).like("foo")) + * + * // Selected column has value "bar" + * ds.select(when(ds('a).like("f"), ds('a)).otherwise(ds('b))) + * }}} + * apache/spark + */ + def like( + literal: String + )(implicit + ev: U =:= String + ): ThisType[T, Boolean] = typed(self.untyped.like(literal)) - /** SQL RLIKE expression (LIKE with Regex). Returns a boolean column based on a regex match. - * {{{ - * val ds = TypedDataset.create(X1("foo") :: Nil) - * // true - * ds.select(ds('a).rlike("foo")) - * - * // true - * ds.select(ds('a).rlike(".*)) - * }}} - * apache/spark - */ - def rlike(literal: String)(implicit ev: U =:= String): ThisType[T, Boolean] = + /** + * SQL RLIKE expression (LIKE with Regex). Returns a boolean column based on a regex match. + * {{{ + * val ds = TypedDataset.create(X1("foo") :: Nil) + * // true + * ds.select(ds('a).rlike("foo")) + * + * // true + * ds.select(ds('a).rlike(".*")) + * }}} + * apache/spark + */ + def rlike( + literal: String + )(implicit + ev: U =:= String + ): ThisType[T, Boolean] = typed(self.untyped.rlike(literal)) - /** String contains another string literal. - * {{{ - * df.filter ( df.col('a).contains("foo") ) - * }}} - * - * @param other a string that is being tested against. - * apache/spark - */ - def contains(other: String)(implicit ev: U =:= String): ThisType[T, Boolean] = + /** + * String contains another string literal. + * {{{ + * df.filter ( df.col('a).contains("foo") ) + * }}} + * + * @param other a string that is being tested against. + * apache/spark + */ + def contains( + other: String + )(implicit + ev: U =:= String + ): ThisType[T, Boolean] = typed(self.untyped.contains(other)) - /** String contains. - * {{{ - * df.filter ( df.col('a).contains(df.col('b) ) - * }}} - * - * @param other a column which values is used as a string that is being tested against. - * apache/spark - */ - def contains[TT, W](other: ThisType[TT, U])(implicit ev: U =:= String, w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * String contains. + * {{{ + * df.filter ( df.col('a).contains(df.col('b)) ) + * }}} + * + * @param other a column whose values are used as a string that is being tested against. + * apache/spark + */ + def contains[TT, W]( + other: ThisType[TT, U] + )(implicit + ev: U =:= String, + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = typed(self.untyped.contains(other.untyped)) - /** String starts with another string literal. - * {{{ - * df.filter ( df.col('a).startsWith("foo") - * }}} - * - * @param other a prefix that is being tested against. - * apache/spark - */ - def startsWith(other: String)(implicit ev: U =:= String): ThisType[T, Boolean] = + /** + * String starts with another string literal. + * {{{ + * df.filter ( df.col('a).startsWith("foo") ) + * }}} + * + * @param other a prefix that is being tested against. + * apache/spark + */ + def startsWith( + other: String + )(implicit + ev: U =:= String + ): ThisType[T, Boolean] = typed(self.untyped.startsWith(other)) - /** String starts with.
- * {{{ - * df.filter ( df.col('a).startsWith(df.col('b)) - * }}} - * - * @param other a column which values is used as a prefix that is being tested against. - * apache/spark - */ - def startsWith[TT, W](other: ThisType[TT, U])(implicit ev: U =:= String, w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * String starts with. + * {{{ + * df.filter ( df.col('a).startsWith(df.col('b)) ) + * }}} + * + * @param other a column whose values are used as a prefix that is being tested against. + * apache/spark + */ + def startsWith[TT, W]( + other: ThisType[TT, U] + )(implicit + ev: U =:= String, + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = typed(self.untyped.startsWith(other.untyped)) - /** String ends with another string literal. - * {{{ - * df.filter ( df.col('a).endsWith("foo") - * }}} - * - * @param other a suffix that is being tested against. - * apache/spark - */ - def endsWith(other: String)(implicit ev: U =:= String): ThisType[T, Boolean] = + /** + * String ends with another string literal. + * {{{ + * df.filter ( df.col('a).endsWith("foo") ) + * }}} + * + * @param other a suffix that is being tested against. + * apache/spark + */ + def endsWith( + other: String + )(implicit + ev: U =:= String + ): ThisType[T, Boolean] = typed(self.untyped.endsWith(other)) - /** String ends with. - * {{{ - * df.filter ( df.col('a).endsWith(df.col('b)) - * }}} - * - * @param other a column which values is used as a suffix that is being tested against. - * apache/spark - */ - def endsWith[TT, W](other: ThisType[TT, U])(implicit ev: U =:= String, w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * String ends with. + * {{{ + * df.filter ( df.col('a).endsWith(df.col('b)) ) + * }}} + * + * @param other a column whose values are used as a suffix that is being tested against. + * apache/spark + */ + def endsWith[TT, W]( + other: ThisType[TT, U] + )(implicit + ev: U =:= String, + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = typed(self.untyped.endsWith(other.untyped)) - /** Boolean AND. - * {{{ - * df.filter ( (df.col('a) === 1).and(df.col('b) > 5) ) - * }}} - */ - def and[TT, W](other: ThisType[TT, Boolean])(implicit w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * Boolean AND. + * {{{ + * df.filter ( (df.col('a) === 1).and(df.col('b) > 5) ) + * }}} + */ + def and[TT, W]( + other: ThisType[TT, Boolean] + )(implicit + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = typed(self.untyped.and(other.untyped)) - /** Boolean AND. - * {{{ - * df.filter ( df.col('a) === 1 && df.col('b) > 5) - * }}} - */ - def && [TT, W](other: ThisType[TT, Boolean])(implicit w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * Boolean AND. + * {{{ + * df.filter ( df.col('a) === 1 && df.col('b) > 5) + * }}} + */ + def &&[TT, W]( + other: ThisType[TT, Boolean] + )(implicit + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = and(other) - /** Boolean OR. - * {{{ - * df.filter ( (df.col('a) === 1).or(df.col('b) > 5) ) - * }}} - */ - def or[TT, W](other: ThisType[TT, Boolean])(implicit w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * Boolean OR. + * {{{ + * df.filter ( (df.col('a) === 1).or(df.col('b) > 5) ) + * }}} + */ + def or[TT, W]( + other: ThisType[TT, Boolean] + )(implicit + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = typed(self.untyped.or(other.untyped)) - /** Boolean OR.
+ * {{{ + * df.filter ( df.col('a) === 1 || df.col('b) > 5) + * }}} + */ + def ||[TT, W]( + other: ThisType[TT, Boolean] + )(implicit + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = or(other) - /** Less than. - * - * {{{ - * // The following selects people younger than the maxAge column. - * df.select(df('age) < df('maxAge) ) - * }}} - * - * @param other another column of the same type - * apache/spark - */ - def <[TT, W](other: ThisType[TT, U])(implicit i0: CatalystOrdered[U], w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * Less than. + * + * {{{ + * // The following selects people younger than the maxAge column. + * df.select(df('age) < df('maxAge) ) + * }}} + * + * @param other another column of the same type + * apache/spark + */ + def <[TT, W]( + other: ThisType[TT, U] + )(implicit + i0: CatalystOrdered[U], + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = typed(self.untyped < other.untyped) - /** Less than or equal to. - * - * {{{ - * // The following selects people younger or equal than the maxAge column. - * df.select(df('age) <= df('maxAge) - * }}} - * - * @param other another column of the same type - * apache/spark - */ - def <=[TT, W](other: ThisType[TT, U])(implicit i0: CatalystOrdered[U], w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * Less than or equal to. + * + * {{{ + * // The following selects people younger or equal than the maxAge column. + * df.select(df('age) <= df('maxAge) + * }}} + * + * @param other another column of the same type + * apache/spark + */ + def <=[TT, W]( + other: ThisType[TT, U] + )(implicit + i0: CatalystOrdered[U], + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = typed(self.untyped <= other.untyped) - /** Greater than. - * {{{ - * // The following selects people older than the maxAge column. - * df.select( df('age) > df('maxAge) ) - * }}} - * - * @param other another column of the same type - * apache/spark - */ - def >[TT, W](other: ThisType[TT, U])(implicit i0: CatalystOrdered[U], w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * Greater than. + * {{{ + * // The following selects people older than the maxAge column. + * df.select( df('age) > df('maxAge) ) + * }}} + * + * @param other another column of the same type + * apache/spark + */ + def >[TT, W]( + other: ThisType[TT, U] + )(implicit + i0: CatalystOrdered[U], + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = typed(self.untyped > other.untyped) - /** Greater than or equal. - * {{{ - * // The following selects people older or equal than the maxAge column. - * df.select( df('age) >= df('maxAge) ) - * }}} - * - * @param other another column of the same type - * apache/spark - */ - def >=[TT, W](other: ThisType[TT, U])(implicit i0: CatalystOrdered[U], w: With.Aux[T, TT, W]): ThisType[W, Boolean] = + /** + * Greater than or equal. + * {{{ + * // The following selects people older or equal than the maxAge column. + * df.select( df('age) >= df('maxAge) ) + * }}} + * + * @param other another column of the same type + * apache/spark + */ + def >=[TT, W]( + other: ThisType[TT, U] + )(implicit + i0: CatalystOrdered[U], + w: With.Aux[T, TT, W] + ): ThisType[W, Boolean] = typed(self.untyped >= other.untyped) - /** Less than. - * {{{ - * // The following selects people younger than 21. - * df.select( df('age) < 21 ) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def <(u: U)(implicit i0: CatalystOrdered[U]): ThisType[T, Boolean] = + /** + * Less than. + * {{{ + * // The following selects people younger than 21. 
+ * df.select( df('age) < 21 ) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def <( + u: U + )(implicit + i0: CatalystOrdered[U] + ): ThisType[T, Boolean] = typed(self.untyped < lit(u)(self.uencoder).untyped) - /** Less than or equal to. - * {{{ - * // The following selects people younger than 22. - * df.select( df('age) <= 2 ) - * }}} - * - * @param u a constant of the same type - * apache/spark - */ - def <=(u: U)(implicit i0: CatalystOrdered[U]): ThisType[T, Boolean] = + /** + * Less than or equal to. + * {{{ + * // The following selects people younger than 22. + * df.select( df('age) <= 21 ) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def <=( + u: U + )(implicit + i0: CatalystOrdered[U] + ): ThisType[T, Boolean] = typed(self.untyped <= lit(u)(self.uencoder).untyped) - /** Greater than. - * {{{ - * // The following selects people older than 21. - * df.select( df('age) > 21 ) - * }}} - * - * @param u another column of the same type - * apache/spark - */ - def >(u: U)(implicit i0: CatalystOrdered[U]): ThisType[T, Boolean] = + /** + * Greater than. + * {{{ + * // The following selects people older than 21. + * df.select( df('age) > 21 ) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def >( + u: U + )(implicit + i0: CatalystOrdered[U] + ): ThisType[T, Boolean] = typed(self.untyped > lit(u)(self.uencoder).untyped) - /** Greater than or equal. - * {{{ - * // The following selects people older than 20. - * df.select( df('age) >= 21 ) - * }}} - * - * @param u another column of the same type - * apache/spark - */ - def >=(u: U)(implicit i0: CatalystOrdered[U]): ThisType[T, Boolean] = + /** + * Greater than or equal. + * {{{ + * // The following selects people older than 20. + * df.select( df('age) >= 21 ) + * }}} + * + * @param u a constant of the same type + * apache/spark + */ + def >=( + u: U + )(implicit + i0: CatalystOrdered[U] + ): ThisType[T, Boolean] = typed(self.untyped >= lit(u)(self.uencoder).untyped) /** - * Returns true if the value of this column is contained in of the arguments. - * {{{ - * // The following selects people with age 15, 20, or 30.
+ * df.select( df('age).isin(15, 20, 30) ) + * }}} + * + * @param values are constants of the same type + * apache/spark + */ + def isin( + values: U* + )(implicit + e: CatalystIsin[U] + ): ThisType[T, Boolean] = + typed(self.untyped.isin(values: _*)) + + /** + * True if the current column is between the lower bound and upper bound, inclusive. + * + * @param lowerBound a constant of the same type + * @param upperBound a constant of the same type + * apache/spark + */ + def between( + lowerBound: U, + upperBound: U + )(implicit + i0: CatalystOrdered[U] + ): ThisType[T, Boolean] = + typed( + self.untyped.between( + lit(lowerBound)(self.uencoder).untyped, + lit(upperBound)(self.uencoder).untyped + ) + ) + + /** + * True if the current column is between the lower bound and upper bound, inclusive. + * + * @param lowerBound another column of the same type + * @param upperBound another column of the same type + * apache/spark + */ + def between[TT1, TT2, W1, W2]( + lowerBound: ThisType[TT1, U], + upperBound: ThisType[TT2, U] + )(implicit i0: CatalystOrdered[U], w0: With.Aux[T, TT1, W1], w1: With.Aux[TT2, W1, W2] ): ThisType[W2, Boolean] = - typed(self.untyped.between(lowerBound.untyped, upperBound.untyped)) + typed(self.untyped.between(lowerBound.untyped, upperBound.untyped)) /** - * Returns a nested column matching the field `symbol`. - * - * @param symbol the field symbol - * @tparam V the type of the nested field - */ - def field[V](symbol: Witness.Lt[Symbol])(implicit + * Returns a nested column matching the field `symbol`. + * + * @param symbol the field symbol + * @tparam V the type of the nested field + */ + def field[V]( + symbol: Witness.Lt[Symbol] + )(implicit i0: TypedColumn.Exists[U, symbol.T, V], i1: TypedEncoder[V] - ): ThisType[T, V] = + ): ThisType[T, V] = typed(self.untyped.getField(symbol.value.name)) } - -sealed class SortedTypedColumn[T, U](val expr: Expression)( - implicit - val uencoder: TypedEncoder[U] -) extends UntypedExpression[T] { - - def this(column: Column)(implicit e: TypedEncoder[U]) = { +sealed class SortedTypedColumn[T, U]( + val expr: Expression + )(implicit + val uencoder: TypedEncoder[U]) + extends UntypedExpression[T] { + + def this( + column: Column + )(implicit + e: TypedEncoder[U] + ) = { this(FramelessInternals.expr(column)) } @@ -894,16 +1276,24 @@ sealed class SortedTypedColumn[T, U](val expr: Expression)( } object SortedTypedColumn { - implicit def defaultAscending[T, U : CatalystOrdered](typedColumn: TypedColumn[T, U]): SortedTypedColumn[T, U] = + + implicit def defaultAscending[T, U: CatalystOrdered]( + typedColumn: TypedColumn[T, U] + ): SortedTypedColumn[T, U] = new SortedTypedColumn[T, U](typedColumn.untyped.asc)(typedColumn.uencoder) - object defaultAscendingPoly extends Poly1 { - implicit def caseTypedColumn[T, U : CatalystOrdered] = at[TypedColumn[T, U]](c => defaultAscending(c)) - implicit def caseTypeSortedColumn[T, U] = at[SortedTypedColumn[T, U]](identity) - } + object defaultAscendingPoly extends Poly1 { + + implicit def caseTypedColumn[T, U: CatalystOrdered] = + at[TypedColumn[T, U]](c => defaultAscending(c)) + + implicit def caseTypeSortedColumn[T, U] = + at[SortedTypedColumn[T, U]](identity) + } } object TypedColumn { + /** Evidence that type `T` has column `K` with type `V`. 
*/ @implicitNotFound(msg = "No column ${K} of type ${V} in ${T}") trait Exists[T, K, V] @@ -912,37 +1302,46 @@ object TypedColumn { trait ExistsMany[T, K <: HList, V] object ExistsMany { - implicit def deriveCons[T, KH, KT <: HList, V0, V1] - (implicit + + implicit def deriveCons[T, KH, KT <: HList, V0, V1]( + implicit head: Exists[T, KH, V0], tail: ExistsMany[V0, KT, V1] ): ExistsMany[T, KH :: KT, V1] = - new ExistsMany[T, KH :: KT, V1] {} + new ExistsMany[T, KH :: KT, V1] {} - implicit def deriveHNil[T, K, V](implicit head: Exists[T, K, V]): ExistsMany[T, K :: HNil, V] = + implicit def deriveHNil[T, K, V]( + implicit + head: Exists[T, K, V] + ): ExistsMany[T, K :: HNil, V] = new ExistsMany[T, K :: HNil, V] {} } object Exists { - def apply[T, V](column: Witness)(implicit e: Exists[T, column.T, V]): Exists[T, column.T, V] = e - implicit def deriveRecord[T, H <: HList, K, V] - (implicit + def apply[T, V]( + column: Witness + )(implicit + e: Exists[T, column.T, V] + ): Exists[T, column.T, V] = e + + implicit def deriveRecord[T, H <: HList, K, V]( + implicit i0: LabelledGeneric.Aux[T, H], i1: Selector.Aux[H, K, V] ): Exists[T, K, V] = new Exists[T, K, V] {} } /** - * {{{ - * import frameless.TypedColumn - * - * case class Foo(id: Int, bar: String) - * - * val colbar: TypedColumn[Foo, String] = TypedColumn { foo: Foo => foo.bar } - * val colid = TypedColumn[Foo, Int](_.id) - * }}} - */ + * {{{ + * import frameless.TypedColumn + * + * case class Foo(id: Int, bar: String) + * + * val colbar: TypedColumn[Foo, String] = TypedColumn { foo: Foo => foo.bar } + * val colid = TypedColumn[Foo, Int](_.id) + * }}} + */ def apply[T, U](x: T => U): TypedColumn[T, U] = macro TypedColumnMacroImpl.applyImpl[T, U] diff --git a/dataset/src/main/scala/frameless/TypedColumnMacroImpl.scala b/dataset/src/main/scala/frameless/TypedColumnMacroImpl.scala index 62fa2765d..7a1981b92 100644 --- a/dataset/src/main/scala/frameless/TypedColumnMacroImpl.scala +++ b/dataset/src/main/scala/frameless/TypedColumnMacroImpl.scala @@ -4,7 +4,10 @@ import scala.reflect.macros.whitebox private[frameless] object TypedColumnMacroImpl { - def applyImpl[T: c.WeakTypeTag, U: c.WeakTypeTag](c: whitebox.Context)(x: c.Tree): c.Expr[TypedColumn[T, U]] = { + def applyImpl[T: c.WeakTypeTag, U: c.WeakTypeTag]( + c: whitebox.Context + )(x: c.Tree + ): c.Expr[TypedColumn[T, U]] = { import c.universe._ val t = c.weakTypeOf[T] @@ -13,7 +16,9 @@ private[frameless] object TypedColumnMacroImpl { def buildExpression(path: List[String]): c.Expr[TypedColumn[T, U]] = { val columnName = path.mkString(".") - c.Expr[TypedColumn[T, U]](q"new _root_.frameless.TypedColumn[$t, $u]((org.apache.spark.sql.functions.col($columnName)).expr)") + c.Expr[TypedColumn[T, U]]( + q"new _root_.frameless.TypedColumn[$t, $u]((org.apache.spark.sql.functions.col($columnName)).expr)" + ) } def abort(msg: String) = c.abort(c.enclosingPosition, msg) @@ -48,34 +53,39 @@ private[frameless] object TypedColumnMacroImpl { } x match { - case fn: Function => fn.body match { - case select: Select if select.name.isTermName => - val expectedRoot: Option[String] = fn.vparams match { - case List(rt) if rt.rhs == EmptyTree => - Option.empty[String] - - case List(rt) => - Some(rt.toString) + case fn: Function => + fn.body match { + case select: Select if select.name.isTermName => + val expectedRoot: Option[String] = fn.vparams match { + case List(rt) if rt.rhs == EmptyTree => + Option.empty[String] + + case List(rt) => + Some(rt.toString) + + case u => + abort( + s"Select expression must have a 
single parameter: ${u mkString ", "}" + ) + } - case u => - abort(s"Select expression must have a single parameter: ${u mkString ", "}") - } + path(select, List.empty) match { + case root :: tail + if (expectedRoot.forall(_ == root) && check(t, tail)) => { + val colPath = tail.mkString(".") - path(select, List.empty) match { - case root :: tail if ( - expectedRoot.forall(_ == root) && check(t, tail)) => { - val colPath = tail.mkString(".") + c.Expr[TypedColumn[T, U]]( + q"new _root_.frameless.TypedColumn[$t, $u]((org.apache.spark.sql.functions.col($colPath)).expr)" + ) + } - c.Expr[TypedColumn[T, U]](q"new _root_.frameless.TypedColumn[$t, $u]((org.apache.spark.sql.functions.col($colPath)).expr)") + case _ => + abort(s"Invalid select expression: $select") } - case _ => - abort(s"Invalid select expression: $select") - } - - case t => - abort(s"Select expression expected: $t") - } + case t => + abort(s"Select expression expected: $t") + } case _ => abort(s"Function expected: $x") diff --git a/dataset/src/main/scala/frameless/TypedDataset.scala b/dataset/src/main/scala/frameless/TypedDataset.scala index add2170b2..28ad4fa5f 100644 --- a/dataset/src/main/scala/frameless/TypedDataset.scala +++ b/dataset/src/main/scala/frameless/TypedDataset.scala @@ -4,36 +4,58 @@ import java.util import frameless.functions.CatalystExplodableCollection import frameless.ops._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Column, DataFrame, Dataset, FramelessInternals, SparkSession} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint} +import org.apache.spark.sql.{ + Column, + DataFrame, + Dataset, + FramelessInternals, + SparkSession +} +import org.apache.spark.sql.catalyst.expressions.{ + Attribute, + AttributeReference, + Literal +} +import org.apache.spark.sql.catalyst.plans.logical.{ Join, JoinHint } import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.types.StructType import shapeless._ import shapeless.labelled.FieldType -import shapeless.ops.hlist.{Diff, IsHCons, Mapper, Prepend, ToTraversable, Tupler} -import shapeless.ops.record.{Keys, Modifier, Remover, Values} +import shapeless.ops.hlist.{ + Diff, + IsHCons, + Mapper, + Prepend, + ToTraversable, + Tupler +} +import shapeless.ops.record.{ Keys, Modifier, Remover, Values } import scala.language.experimental.macros -/** [[TypedDataset]] is a safer interface for working with `Dataset`. - * - * NOTE: Prefer `TypedDataset.create` over `new TypedDataset` unless you - * know what you are doing. - * - * Documentation marked "apache/spark" is thanks to apache/spark Contributors - * at https://github.com/apache/spark, licensed under Apache v2.0 available at - * http://www.apache.org/licenses/LICENSE-2.0 - */ -class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val encoder: TypedEncoder[T]) +/** + * [[TypedDataset]] is a safer interface for working with `Dataset`. + * + * NOTE: Prefer `TypedDataset.create` over `new TypedDataset` unless you + * know what you are doing. 
+ * + * Documentation marked "apache/spark" is thanks to apache/spark Contributors + * at https://github.com/apache/spark, licensed under Apache v2.0 available at + * http://www.apache.org/licenses/LICENSE-2.0 + */ +class TypedDataset[T] protected[frameless] ( + val dataset: Dataset[T] + )(implicit + val encoder: TypedEncoder[T]) extends TypedDatasetForwarded[T] { self => private implicit val spark: SparkSession = dataset.sparkSession - /** Aggregates on the entire Dataset without groups. - * - * apache/spark - */ + /** + * Aggregates on the entire Dataset without groups. + * + * apache/spark + */ def agg[A](ca: TypedAggregate[T, A]): TypedDataset[A] = { implicit val ea = ca.uencoder val tuple1: TypedDataset[Tuple1[A]] = aggMany(ca) @@ -42,10 +64,8 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val TypedEncoder[A].catalystRepr match { case StructType(_) => // if column is struct, we use all its fields - val df = tuple1 - .dataset - .selectExpr("_1.*") - .as[A](TypedExpressionEncoder[A]) + val df = + tuple1.dataset.selectExpr("_1.*").as[A](TypedExpressionEncoder[A]) TypedDataset.create(df) case other => @@ -54,52 +74,59 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val } } - /** Aggregates on the entire Dataset without groups. - * - * apache/spark - */ + /** + * Aggregates on the entire Dataset without groups. + * + * apache/spark + */ def agg[A, B]( - ca: TypedAggregate[T, A], - cb: TypedAggregate[T, B] - ): TypedDataset[(A, B)] = { + ca: TypedAggregate[T, A], + cb: TypedAggregate[T, B] + ): TypedDataset[(A, B)] = { implicit val (ea, eb) = (ca.uencoder, cb.uencoder) aggMany(ca, cb) } - /** Aggregates on the entire Dataset without groups. - * - * apache/spark - */ + /** + * Aggregates on the entire Dataset without groups. + * + * apache/spark + */ def agg[A, B, C]( - ca: TypedAggregate[T, A], - cb: TypedAggregate[T, B], - cc: TypedAggregate[T, C] - ): TypedDataset[(A, B, C)] = { + ca: TypedAggregate[T, A], + cb: TypedAggregate[T, B], + cc: TypedAggregate[T, C] + ): TypedDataset[(A, B, C)] = { implicit val (ea, eb, ec) = (ca.uencoder, cb.uencoder, cc.uencoder) aggMany(ca, cb, cc) } - /** Aggregates on the entire Dataset without groups. - * - * apache/spark - */ + /** + * Aggregates on the entire Dataset without groups. + * + * apache/spark + */ def agg[A, B, C, D]( - ca: TypedAggregate[T, A], - cb: TypedAggregate[T, B], - cc: TypedAggregate[T, C], - cd: TypedAggregate[T, D] - ): TypedDataset[(A, B, C, D)] = { - implicit val (ea, eb, ec, ed) = (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder) + ca: TypedAggregate[T, A], + cb: TypedAggregate[T, B], + cc: TypedAggregate[T, C], + cd: TypedAggregate[T, D] + ): TypedDataset[(A, B, C, D)] = { + implicit val (ea, eb, ec, ed) = + (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder) aggMany(ca, cb, cc, cd) } - /** Aggregates on the entire Dataset without groups. - * - * apache/spark - */ + /** + * Aggregates on the entire Dataset without groups. 
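The `agg` overloads above all aggregate the whole dataset without grouping; a small sketch of how they are called, where `Person`, `people`, and the use of `max` from `frameless.functions.aggregate` are assumptions (the Scaladoc in this file only references `count` explicitly):

```scala
import frameless.TypedDataset
import frameless.functions.aggregate._

case class Person(name: String, age: Int)
val people: TypedDataset[Person] = ???

// Two aggregates pick the two-argument `agg` overload and produce a
// single-row dataset of a tuple.
val summary: TypedDataset[(Long, Int)] =
  people.agg(count[Person](), max(people('age)))
```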
+ * + * apache/spark + */ object aggMany extends ProductArgs { - def applyProduct[U <: HList, Out0 <: HList, Out](columns: U) - (implicit + + def applyProduct[U <: HList, Out0 <: HList, Out]( + columns: U + )(implicit i0: AggregateTypes.Aux[T, U, Out0], i1: ToTraversable.Aux[U, List, UntypedExpression[T]], i2: Tupler.Aux[Out0, Out], @@ -109,7 +136,7 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val val underlyingColumns = columns.toList[UntypedExpression[T]] val cols: Seq[Column] = for { (c, i) <- columns.toList[UntypedExpression[T]].zipWithIndex - } yield new Column(c.expr).as(s"_${i+1}") + } yield new Column(c.expr).as(s"_${i + 1}") // Workaround to SPARK-20346. One alternative is to allow the result to be Vector(null) for empty DataFrames. // Another one would be to return an Option. @@ -117,129 +144,163 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val for { (c, i) <- underlyingColumns.zipWithIndex if !c.uencoder.nullable - } yield s"_${i+1} is not null" - ).mkString(" or ") + } yield s"_${i + 1} is not null" + ).mkString(" or ") - val selected = dataset.toDF().agg(cols.head, cols.tail:_*).as[Out](TypedExpressionEncoder[Out]) - TypedDataset.create[Out](if (filterStr.isEmpty) selected else selected.filter(filterStr)) + val selected = dataset + .toDF() + .agg(cols.head, cols.tail: _*) + .as[Out](TypedExpressionEncoder[Out]) + TypedDataset.create[Out]( + if (filterStr.isEmpty) selected else selected.filter(filterStr) + ) } } /** Returns a new [[TypedDataset]] where each record has been mapped on to the specified type. */ - def as[U]()(implicit as: As[T, U]): TypedDataset[U] = { + def as[U]( + )(implicit + as: As[T, U] + ): TypedDataset[U] = { implicit val uencoder = as.encoder TypedDataset.create(dataset.as[U](TypedExpressionEncoder[U])) } - /** Returns a checkpointed version of this [[TypedDataset]]. Checkpointing can be used to truncate the - * logical plan of this Dataset, which is especially useful in iterative algorithms where the - * plan may grow exponentially. It will be saved to files inside the checkpoint - * directory set with `SparkContext#setCheckpointDir`. - * - * Differs from `Dataset#checkpoint` by wrapping its result into an effect-suspending `F[_]`. - * - * apache/spark - */ - def checkpoint[F[_]](eager: Boolean)(implicit F: SparkDelay[F]): F[TypedDataset[T]] = + /** + * Returns a checkpointed version of this [[TypedDataset]]. Checkpointing can be used to truncate the + * logical plan of this Dataset, which is especially useful in iterative algorithms where the + * plan may grow exponentially. It will be saved to files inside the checkpoint + * directory set with `SparkContext#setCheckpointDir`. + * + * Differs from `Dataset#checkpoint` by wrapping its result into an effect-suspending `F[_]`. + * + * apache/spark + */ + def checkpoint[F[_]]( + eager: Boolean + )(implicit + F: SparkDelay[F] + ): F[TypedDataset[T]] = F.delay(TypedDataset.create[T](dataset.checkpoint(eager))) - /** Returns a new [[TypedDataset]] where each record has been mapped on to the specified type. - * Unlike `as` the projection U may include a subset of the columns of T and the column names and types must agree. - * - * {{{ - * case class Foo(i: Int, j: String) - * case class Bar(j: String) - * - * val t: TypedDataset[Foo] = ... 
- * val b: TypedDataset[Bar] = t.project[Bar] - * - * case class BarErr(e: String) - * // The following does not compile because `Foo` doesn't have a field with name `e` - * val e: TypedDataset[BarErr] = t.project[BarErr] - * }}} - */ - def project[U](implicit projector: SmartProject[T,U]): TypedDataset[U] = projector.apply(this) - - /** Returns a new [[TypedDataset]] that contains the elements of both this and the `other` [[TypedDataset]] - * combined. - * - * Note that, this function is not a typical set union operation, in that it does not eliminate - * duplicate items. As such, it is analogous to `UNION ALL` in SQL. - * - * Differs from `Dataset#union` by aligning fields if possible. - * It will not compile if `Datasets` have not compatible schema. - * - * Example: - * {{{ - * case class Foo(x: Int, y: Long) - * case class Bar(y: Long, x: Int) - * case class Faz(x: Int, y: Int, z: Int) - * - * foo: TypedDataset[Foo] = ... - * bar: TypedDataset[Bar] = ... - * faz: TypedDataset[Faz] = ... - * - * foo union bar: TypedDataset[Foo] - * foo union faz: TypedDataset[Foo] - * // won't compile, you need to reverse order, you can't project from less fields to more - * faz union foo - * - * }}} - * - * apache/spark - */ - def union[U: TypedEncoder](other: TypedDataset[U])(implicit projector: SmartProject[U, T]): TypedDataset[T] = + /** + * Returns a new [[TypedDataset]] where each record has been mapped on to the specified type. + * Unlike `as` the projection U may include a subset of the columns of T and the column names and types must agree. + * + * {{{ + * case class Foo(i: Int, j: String) + * case class Bar(j: String) + * + * val t: TypedDataset[Foo] = ... + * val b: TypedDataset[Bar] = t.project[Bar] + * + * case class BarErr(e: String) + * // The following does not compile because `Foo` doesn't have a field with name `e` + * val e: TypedDataset[BarErr] = t.project[BarErr] + * }}} + */ + def project[U]( + implicit + projector: SmartProject[T, U] + ): TypedDataset[U] = projector.apply(this) + + /** + * Returns a new [[TypedDataset]] that contains the elements of both this and the `other` [[TypedDataset]] + * combined. + * + * Note that, this function is not a typical set union operation, in that it does not eliminate + * duplicate items. As such, it is analogous to `UNION ALL` in SQL. + * + * Differs from `Dataset#union` by aligning fields if possible. + * It will not compile if `Datasets` have not compatible schema. + * + * Example: + * {{{ + * case class Foo(x: Int, y: Long) + * case class Bar(y: Long, x: Int) + * case class Faz(x: Int, y: Int, z: Int) + * + * foo: TypedDataset[Foo] = ... + * bar: TypedDataset[Bar] = ... + * faz: TypedDataset[Faz] = ... + * + * foo union bar: TypedDataset[Foo] + * foo union faz: TypedDataset[Foo] + * // won't compile, you need to reverse order, you can't project from less fields to more + * faz union foo + * + * }}} + * + * apache/spark + */ + def union[U: TypedEncoder]( + other: TypedDataset[U] + )(implicit + projector: SmartProject[U, T] + ): TypedDataset[T] = TypedDataset.create(dataset.union(other.project[T].dataset)) - /** Returns a new [[TypedDataset]] that contains the elements of both this and the `other` [[TypedDataset]] - * combined. - * - * Note that, this function is not a typical set union operation, in that it does not eliminate - * duplicate items. As such, it is analogous to `UNION ALL` in SQL. 
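The `project` and field-aligning `union` described above can be exercised as follows; this sketch mirrors the `Foo`/`Bar` example from the Scaladoc, with `JustX` and the `???` placeholders added purely for illustration:

```scala
import frameless.TypedDataset

case class Foo(x: Int, y: Long)
case class Bar(y: Long, x: Int)
case class JustX(x: Int)

val foos: TypedDataset[Foo] = ???
val bars: TypedDataset[Bar] = ???

// `union` aligns fields by name and type through SmartProject, so the
// differing declaration order in Bar is fine; the result keeps the left type.
val all: TypedDataset[Foo] = foos.union(bars)

// `project` narrows to a subset of fields with matching names and types.
val xs: TypedDataset[JustX] = foos.project[JustX]
```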
- * - * apache/spark - */ + /** + * Returns a new [[TypedDataset]] that contains the elements of both this and the `other` [[TypedDataset]] + * combined. + * + * Note that, this function is not a typical set union operation, in that it does not eliminate + * duplicate items. As such, it is analogous to `UNION ALL` in SQL. + * + * apache/spark + */ def union(other: TypedDataset[T]): TypedDataset[T] = { TypedDataset.create(dataset.union(other.dataset)) } - /** Returns the number of elements in the [[TypedDataset]]. - * - * Differs from `Dataset#count` by wrapping its result into an effect-suspending `F[_]`. - */ - def count[F[_]]()(implicit F: SparkDelay[F]): F[Long] = + /** + * Returns the number of elements in the [[TypedDataset]]. + * + * Differs from `Dataset#count` by wrapping its result into an effect-suspending `F[_]`. + */ + def count[F[_]]( + )(implicit + F: SparkDelay[F] + ): F[Long] = F.delay(dataset.count()) - /** Returns `TypedColumn` of type `A` given its name (alias for `col`). - * - * {{{ - * tf('id) - * }}} - * - * It is statically checked that column with such name exists and has type `A`. - */ - def apply[A](column: Witness.Lt[Symbol]) - (implicit + /** + * Returns `TypedColumn` of type `A` given its name (alias for `col`). + * + * {{{ + * tf('id) + * }}} + * + * It is statically checked that column with such name exists and has type `A`. + */ + def apply[A]( + column: Witness.Lt[Symbol] + )(implicit i0: TypedColumn.Exists[T, column.T, A], i1: TypedEncoder[A] ): TypedColumn[T, A] = col(column) - /** Returns `TypedColumn` of type `A` given its name. - * - * {{{ - * tf.col('id) - * }}} - * - * It is statically checked that column with such name exists and has type `A`. - */ - def col[A](column: Witness.Lt[Symbol]) - (implicit + /** + * Returns `TypedColumn` of type `A` given its name. + * + * {{{ + * tf.col('id) + * }}} + * + * It is statically checked that column with such name exists and has type `A`. + */ + def col[A]( + column: Witness.Lt[Symbol] + )(implicit i0: TypedColumn.Exists[T, column.T, A], i1: TypedEncoder[A] ): TypedColumn[T, A] = - new TypedColumn[T, A](dataset(column.value.name).as[A](TypedExpressionEncoder[A])) + new TypedColumn[T, A]( + dataset(column.value.name).as[A](TypedExpressionEncoder[A]) + ) - /** Returns `TypedColumn` of type `A` given a lambda indicating the field. + /** + * Returns `TypedColumn` of type `A` given a lambda indicating the field. * * {{{ * td.col(_.id) @@ -250,12 +311,13 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val def col[A](x: Function1[T, A]): TypedColumn[T, A] = macro TypedColumnMacroImpl.applyImpl[T, A] - /** Projects the entire `TypedDataset[T]` into a single column of type `TypedColumn[T,T]`. - * {{{ - * ts: TypedDataset[Foo] = ... - * ts.select(ts.asCol, ts.asCol): TypedDataset[(Foo,Foo)] - * }}} - */ + /** + * Projects the entire `TypedDataset[T]` into a single column of type `TypedColumn[T,T]`. + * {{{ + * ts: TypedDataset[Foo] = ... + * ts.select(ts.asCol, ts.asCol): TypedDataset[(Foo,Foo)] + * }}} + */ def asCol: TypedColumn[T, T] = { val projectedColumn: Column = encoder.catalystRepr match { case StructType(_) => @@ -265,78 +327,98 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val case _ => dataset.col(dataset.columns.head) } - - new TypedColumn[T,T](projectedColumn) + + new TypedColumn[T, T](projectedColumn) } - /** References the entire `TypedDataset[T]` as a single column - * of type `TypedColumn[T,T]` so it can be used in a join operation. 
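A sketch of the equivalent ways to reference a column and of `asCol`, as provided by the definitions above and below; `Person` and `people` are hypothetical:

```scala
import frameless.TypedDataset

case class Person(name: String, age: Int)
val people: TypedDataset[Person] = ???

// All three are checked at compile time against Person's fields.
val a1 = people('age)        // apply, alias for col
val a2 = people.col('age)    // Symbol-based col
val a3 = people.col(_.age)   // macro-based col, backed by TypedColumnMacroImpl

// `asCol` projects the whole row as a single column.
val paired: TypedDataset[(Person, Int)] =
  people.select(people.asCol, people('age))
```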
- * - * {{{ - * def nameJoin(ds1: TypedDataset[Person], ds2: TypedDataset[Name]) = - * ds1.joinLeftSemi(ds2)(ds1.col('name) === ds2.asJoinColValue) - * }}} - */ - def asJoinColValue(implicit i0: IsValueClass[T]): TypedColumn[T, T] = { + /** + * References the entire `TypedDataset[T]` as a single column + * of type `TypedColumn[T,T]` so it can be used in a join operation. + * + * {{{ + * def nameJoin(ds1: TypedDataset[Person], ds2: TypedDataset[Name]) = + * ds1.joinLeftSemi(ds2)(ds1.col('name) === ds2.asJoinColValue) + * }}} + */ + def asJoinColValue( + implicit + i0: IsValueClass[T] + ): TypedColumn[T, T] = { import _root_.frameless.syntax._ dataset.col("value").typedColumn } object colMany extends SingletonProductArgs { - def applyProduct[U <: HList, Out](columns: U) - (implicit + + def applyProduct[U <: HList, Out]( + columns: U + )(implicit i0: TypedColumn.ExistsMany[T, U, Out], i1: TypedEncoder[Out], i2: ToTraversable.Aux[U, List, Symbol] ): TypedColumn[T, Out] = { - val names = columns.toList[Symbol].map(_.name) - val colExpr = FramelessInternals.resolveExpr(dataset, names) - new TypedColumn[T, Out](colExpr) - } + val names = columns.toList[Symbol].map(_.name) + val colExpr = FramelessInternals.resolveExpr(dataset, names) + new TypedColumn[T, Out](colExpr) + } } - /** Right hand side disambiguation of `col` for join expressions. - * To be used when writting self-joins, noop in other circumstances. - * - * Note: In vanilla Spark, disambiguation in self-joins is acheaved using - * String based aliases, which is obviously unsafe. - */ - def colRight[A](column: Witness.Lt[Symbol]) - (implicit + /** + * Right hand side disambiguation of `col` for join expressions. + * To be used when writting self-joins, noop in other circumstances. + * + * Note: In vanilla Spark, disambiguation in self-joins is acheaved using + * String based aliases, which is obviously unsafe. + */ + def colRight[A]( + column: Witness.Lt[Symbol] + )(implicit i0: TypedColumn.Exists[T, column.T, A], i1: TypedEncoder[A] ): TypedColumn[T, A] = - new TypedColumn[T, A](FramelessInternals.DisambiguateRight(col(column).expr)) - - /** Left hand side disambiguation of `col` for join expressions. - * To be used when writting self-joins, noop in other circumstances. - * - * Note: In vanilla Spark, disambiguation in self-joins is acheaved using - * String based aliases, which is obviously unsafe. - */ - def colLeft[A](column: Witness.Lt[Symbol]) - (implicit + new TypedColumn[T, A]( + FramelessInternals.DisambiguateRight(col(column).expr) + ) + + /** + * Left hand side disambiguation of `col` for join expressions. + * To be used when writting self-joins, noop in other circumstances. + * + * Note: In vanilla Spark, disambiguation in self-joins is acheaved using + * String based aliases, which is obviously unsafe. + */ + def colLeft[A]( + column: Witness.Lt[Symbol] + )(implicit i0: TypedColumn.Exists[T, column.T, A], i1: TypedEncoder[A] ): TypedColumn[T, A] = - new TypedColumn[T, A](FramelessInternals.DisambiguateLeft(col(column).expr)) - - /** Returns a `Seq` that contains all the elements in this [[TypedDataset]]. - * - * Running this operation requires moving all the data into the application's driver process, and - * doing so on a very large [[TypedDataset]] can crash the driver process with OutOfMemoryError. - * - * Differs from `Dataset#collect` by wrapping its result into an effect-suspending `F[_]`. 
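The `colLeft`/`colRight` disambiguators above exist for self-joins; a sketch, assuming the `===` column comparison defined elsewhere in this file and a hypothetical `Employee` type:

```scala
import frameless.TypedDataset

case class Employee(id: Long, managerId: Long, name: String)
val employees: TypedDataset[Employee] = ???

// Without colLeft/colRight both sides of a self-join would resolve to the
// same attribute; here each side is picked explicitly.
val withManager: TypedDataset[(Employee, Employee)] =
  employees.joinInner(employees)(
    employees.colLeft('managerId) === employees.colRight('id)
  )
```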
- */ - def collect[F[_]]()(implicit F: SparkDelay[F]): F[Seq[T]] = + new TypedColumn[T, A](FramelessInternals.DisambiguateLeft(col(column).expr)) + + /** + * Returns a `Seq` that contains all the elements in this [[TypedDataset]]. + * + * Running this operation requires moving all the data into the application's driver process, and + * doing so on a very large [[TypedDataset]] can crash the driver process with OutOfMemoryError. + * + * Differs from `Dataset#collect` by wrapping its result into an effect-suspending `F[_]`. + */ + def collect[F[_]]( + )(implicit + F: SparkDelay[F] + ): F[Seq[T]] = F.delay(dataset.collect().toSeq) - /** Optionally returns the first element in this [[TypedDataset]]. - * - * Differs from `Dataset#first` by wrapping its result into an `Option` and an effect-suspending `F[_]`. - */ - def firstOption[F[_]]()(implicit F: SparkDelay[F]): F[Option[T]] = + /** + * Optionally returns the first element in this [[TypedDataset]]. + * + * Differs from `Dataset#first` by wrapping its result into an `Option` and an effect-suspending `F[_]`. + */ + def firstOption[F[_]]( + )(implicit + F: SparkDelay[F] + ): F[Option[T]] = F.delay { try { Option(dataset.first()) @@ -345,354 +427,462 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val } } - /** Returns the first `num` elements of this [[TypedDataset]] as a `Seq`. - * - * Running take requires moving data into the application's driver process, and doing so with - * a very large `num` can crash the driver process with OutOfMemoryError. - * - * Differs from `Dataset#take` by wrapping its result into an effect-suspending `F[_]`. - * - * apache/spark - */ - def take[F[_]](num: Int)(implicit F: SparkDelay[F]): F[Seq[T]] = + /** + * Returns the first `num` elements of this [[TypedDataset]] as a `Seq`. + * + * Running take requires moving data into the application's driver process, and doing so with + * a very large `num` can crash the driver process with OutOfMemoryError. + * + * Differs from `Dataset#take` by wrapping its result into an effect-suspending `F[_]`. + * + * apache/spark + */ + def take[F[_]]( + num: Int + )(implicit + F: SparkDelay[F] + ): F[Seq[T]] = F.delay(dataset.take(num).toSeq) - /** Return an iterator that contains all rows in this [[TypedDataset]]. - * - * The iterator will consume as much memory as the largest partition in this [[TypedDataset]]. - * - * NOTE: this results in multiple Spark jobs, and if the input [[TypedDataset]] is the result - * of a wide transformation (e.g. join with different partitioners), to avoid - * recomputing the input [[TypedDataset]] should be cached first. - * - * Differs from `Dataset#toLocalIterator()` by wrapping its result into an effect-suspending `F[_]`. - * - * apache/spark - */ - def toLocalIterator[F[_]]()(implicit F: SparkDelay[F]): F[util.Iterator[T]] = + /** + * Return an iterator that contains all rows in this [[TypedDataset]]. + * + * The iterator will consume as much memory as the largest partition in this [[TypedDataset]]. + * + * NOTE: this results in multiple Spark jobs, and if the input [[TypedDataset]] is the result + * of a wide transformation (e.g. join with different partitioners), to avoid + * recomputing the input [[TypedDataset]] should be cached first. + * + * Differs from `Dataset#toLocalIterator()` by wrapping its result into an effect-suspending `F[_]`. 
+ * + * apache/spark + */ + def toLocalIterator[F[_]]( + )(implicit + F: SparkDelay[F] + ): F[util.Iterator[T]] = F.delay(dataset.toLocalIterator()) - /** Alias for firstOption(). - */ - def headOption[F[_]]()(implicit F: SparkDelay[F]): F[Option[T]] = firstOption() + /** + * Alias for firstOption(). + */ + def headOption[F[_]]( + )(implicit + F: SparkDelay[F] + ): F[Option[T]] = firstOption() - /** Alias for take(). - */ - def head[F[_]](num: Int)(implicit F: SparkDelay[F]): F[Seq[T]] = take(num) + /** + * Alias for take(). + */ + def head[F[_]]( + num: Int + )(implicit + F: SparkDelay[F] + ): F[Seq[T]] = take(num) // $COVERAGE-OFF$ - /** Alias for firstOption(). - */ - @deprecated("Method may throw exception. Use headOption or firstOption instead.", "0.5.0") + /** + * Alias for firstOption(). + */ + @deprecated( + "Method may throw exception. Use headOption or firstOption instead.", + "0.5.0" + ) def head: T = dataset.head() - /** Alias for firstOption(). - */ - @deprecated("Method may throw exception. Use headOption or firstOption instead.", "0.5.0") + /** + * Alias for firstOption(). + */ + @deprecated( + "Method may throw exception. Use headOption or firstOption instead.", + "0.5.0" + ) def first: T = dataset.head() // $COVERAGE-ONN$ - /** Displays the content of this [[TypedDataset]] in a tabular form. Strings more than 20 characters - * will be truncated, and all cells will be aligned right. For example: - * {{{ - * year month AVG('Adj Close) MAX('Adj Close) - * 1980 12 0.503218 0.595103 - * 1981 01 0.523289 0.570307 - * 1982 02 0.436504 0.475256 - * 1983 03 0.410516 0.442194 - * 1984 04 0.450090 0.483521 - * }}} - * @param numRows Number of rows to show - * @param truncate Whether truncate long strings. If true, strings more than 20 characters will - * be truncated and all cells will be aligned right - * - * Differs from `Dataset#show` by wrapping its result into an effect-suspending `F[_]`. - * - * apache/spark - */ - def show[F[_]](numRows: Int = 20, truncate: Boolean = true)(implicit F: SparkDelay[F]): F[Unit] = + /** + * Displays the content of this [[TypedDataset]] in a tabular form. Strings more than 20 characters + * will be truncated, and all cells will be aligned right. For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * @param numRows Number of rows to show + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * Differs from `Dataset#show` by wrapping its result into an effect-suspending `F[_]`. + * + * apache/spark + */ + def show[F[_]]( + numRows: Int = 20, + truncate: Boolean = true + )(implicit + F: SparkDelay[F] + ): F[Unit] = F.delay(dataset.show(numRows, truncate)) - /** Returns a new [[frameless.TypedDataset]] that only contains elements where `column` is `true`. - * - * Differs from `TypedDatasetForward#filter` by taking a `TypedColumn[T, Boolean]` instead of a - * `T => Boolean`. Using a column expression instead of a regular function save one Spark → Scala - * deserialization which leads to better performance. - */ + /** + * Returns a new [[frameless.TypedDataset]] that only contains elements where `column` is `true`. + * + * Differs from `TypedDatasetForward#filter` by taking a `TypedColumn[T, Boolean]` instead of a + * `T => Boolean`. 
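The actions above (`collect`, `take`, `count`, `show`, ...) return their results suspended in an effect `F[_]` given a `SparkDelay[F]`; a sketch using cats-effect `IO` through the frameless-cats module. `Person` and `people` are hypothetical, and the example assumes a cats-effect version whose `Sync[IO]` matches the instances provided by `frameless.cats.implicits`:

```scala
import cats.effect.IO
import frameless.TypedDataset
import frameless.cats.implicits._ // SparkDelay[IO] derived from Sync[IO]

case class Person(name: String, age: Int)
val people: TypedDataset[Person] = ???

// Nothing runs until the IO is executed.
val program: IO[Unit] = for {
  total <- people.count[IO]()
  first <- people.take[IO](5)
  _     <- people.show[IO](numRows = 5, truncate = false)
} yield ()
```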
Using a column expression instead of a regular function save one Spark → Scala + * deserialization which leads to better performance. + */ def filter(column: TypedColumn[T, Boolean]): TypedDataset[T] = { - val filtered = dataset.toDF() - .filter(column.untyped) - .as[T](TypedExpressionEncoder[T]) + val filtered = + dataset.toDF().filter(column.untyped).as[T](TypedExpressionEncoder[T]) TypedDataset.create[T](filtered) } - /** Runs `func` on each element of this [[TypedDataset]]. - * - * Differs from `Dataset#foreach` by wrapping its result into an effect-suspending `F[_]`. - */ - def foreach[F[_]](func: T => Unit)(implicit F: SparkDelay[F]): F[Unit] = + /** + * Runs `func` on each element of this [[TypedDataset]]. + * + * Differs from `Dataset#foreach` by wrapping its result into an effect-suspending `F[_]`. + */ + def foreach[F[_]]( + func: T => Unit + )(implicit + F: SparkDelay[F] + ): F[Unit] = F.delay(dataset.foreach(func)) - /** Runs `func` on each partition of this [[TypedDataset]]. - * - * Differs from `Dataset#foreachPartition` by wrapping its result into an effect-suspending `F[_]`. - */ - def foreachPartition[F[_]](func: Iterator[T] => Unit)(implicit F: SparkDelay[F]): F[Unit] = + /** + * Runs `func` on each partition of this [[TypedDataset]]. + * + * Differs from `Dataset#foreachPartition` by wrapping its result into an effect-suspending `F[_]`. + */ + def foreachPartition[F[_]]( + func: Iterator[T] => Unit + )(implicit + F: SparkDelay[F] + ): F[Unit] = F.delay(dataset.foreachPartition(func)) /** - * Create a multi-dimensional cube for the current [[TypedDataset]] using the specified column, - * so we can run aggregation on it. - * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. - * - * Differs from `Dataset#cube` by wrapping values into `Option` instead of returning `null`. - * - * apache/spark - */ + * Create a multi-dimensional cube for the current [[TypedDataset]] using the specified column, + * so we can run aggregation on it. + * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. + * + * Differs from `Dataset#cube` by wrapping values into `Option` instead of returning `null`. + * + * apache/spark + */ def cube[K1]( - c1: TypedColumn[T, K1] - ): Cube1Ops[K1, T] = new Cube1Ops[K1, T](this, c1) - - /** - * Create a multi-dimensional cube for the current [[TypedDataset]] using the specified columns, - * so we can run aggregation on them. - * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. - * - * Differs from `Dataset#cube` by wrapping values into `Option` instead of returning `null`. - * - * apache/spark - */ + c1: TypedColumn[T, K1] + ): Cube1Ops[K1, T] = new Cube1Ops[K1, T](this, c1) + + /** + * Create a multi-dimensional cube for the current [[TypedDataset]] using the specified columns, + * so we can run aggregation on them. + * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. + * + * Differs from `Dataset#cube` by wrapping values into `Option` instead of returning `null`. + * + * apache/spark + */ def cube[K1, K2]( - c1: TypedColumn[T, K1], - c2: TypedColumn[T, K2] - ): Cube2Ops[K1, K2, T] = new Cube2Ops[K1, K2, T](this, c1, c2) - - /** - * Create a multi-dimensional cube for the current [[TypedDataset]] using the specified columns, - * so we can run aggregation on them. - * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. 
- * - * {{{ - * case class MyClass(a: Int, b: Int, c: Int) - * val ds: TypedDataset[MyClass] - - * val cubeDataset: TypedDataset[(Option[A], Option[B], Long)] = - * ds.cubeMany(ds('a), ds('b)).agg(count[MyClass]()) - * - * // original dataset: - * a b c - * 10 20 1 - * 15 25 2 - * - * // after aggregation: - * _1 _2 _3 - * 15 null 1 - * 15 25 1 - * null null 2 - * null 25 1 - * null 20 1 - * 10 null 1 - * 10 20 1 - * - * }}} - * - * Differs from `Dataset#cube` by wrapping values into `Option` instead of returning `null`. - * - * apache/spark - */ + c1: TypedColumn[T, K1], + c2: TypedColumn[T, K2] + ): Cube2Ops[K1, K2, T] = new Cube2Ops[K1, K2, T](this, c1, c2) + + /** + * Create a multi-dimensional cube for the current [[TypedDataset]] using the specified columns, + * so we can run aggregation on them. + * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. + * + * {{{ + * case class MyClass(a: Int, b: Int, c: Int) + * val ds: TypedDataset[MyClass] + * + * val cubeDataset: TypedDataset[(Option[A], Option[B], Long)] = + * ds.cubeMany(ds('a), ds('b)).agg(count[MyClass]()) + * + * // original dataset: + * a b c + * 10 20 1 + * 15 25 2 + * + * // after aggregation: + * _1 _2 _3 + * 15 null 1 + * 15 25 1 + * null null 2 + * null 25 1 + * null 20 1 + * 10 null 1 + * 10 20 1 + * + * }}} + * + * Differs from `Dataset#cube` by wrapping values into `Option` instead of returning `null`. + * + * apache/spark + */ object cubeMany extends ProductArgs { - def applyProduct[TK <: HList, K <: HList, KT](groupedBy: TK) - (implicit + + def applyProduct[TK <: HList, K <: HList, KT]( + groupedBy: TK + )(implicit i0: ColumnTypes.Aux[T, TK, K], i1: Tupler.Aux[K, KT], i2: ToTraversable.Aux[TK, List, UntypedExpression[T]] - ): CubeManyOps[T, TK, K, KT] = new CubeManyOps[T, TK, K, KT](self, groupedBy) + ): CubeManyOps[T, TK, K, KT] = + new CubeManyOps[T, TK, K, KT](self, groupedBy) } /** - * Groups the [[TypedDataset]] using the specified columns, so that we can run aggregation on them. - * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. - * - * apache/spark - */ + * Groups the [[TypedDataset]] using the specified columns, so that we can run aggregation on them. + * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. + * + * apache/spark + */ def groupBy[K1]( - c1: TypedColumn[T, K1] - ): GroupedBy1Ops[K1, T] = new GroupedBy1Ops[K1, T](this, c1) + c1: TypedColumn[T, K1] + ): GroupedBy1Ops[K1, T] = new GroupedBy1Ops[K1, T](this, c1) /** - * Groups the [[TypedDataset]] using the specified columns, so that we can run aggregation on them. - * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. - * - * apache/spark - */ + * Groups the [[TypedDataset]] using the specified columns, so that we can run aggregation on them. + * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. + * + * apache/spark + */ def groupBy[K1, K2]( - c1: TypedColumn[T, K1], - c2: TypedColumn[T, K2] - ): GroupedBy2Ops[K1, K2, T] = new GroupedBy2Ops[K1, K2, T](this, c1, c2) - - /** - * Groups the [[TypedDataset]] using the specified columns, so that we can run aggregation on them. - * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. 
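A sketch of `groupBy` followed by `agg`, in the spirit of the `groupByMany` example in the surrounding Scaladoc; `Purchase` and `purchases` are hypothetical:

```scala
import frameless.TypedDataset
import frameless.functions.aggregate._

case class Purchase(customer: String, amount: Long)
val purchases: TypedDataset[Purchase] = ???

// The grouping key plus the aggregates come back as a typed tuple.
val perCustomer: TypedDataset[(String, Long)] =
  purchases.groupBy(purchases('customer)).agg(count[Purchase]())
```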
- * - * {{{ - * case class MyClass(a: Int, b: Int, c: Int) - * val ds: TypedDataset[MyClass] - * - * val cubeDataset: TypedDataset[(Option[A], Option[B], Long)] = - * ds.groupByMany(ds('a), ds('b)).agg(count[MyClass]()) - * - * // original dataset: - * a b c - * 10 20 1 - * 15 25 2 - * - * // after aggregation: - * _1 _2 _3 - * 10 20 1 - * 15 25 1 - * - * }}} - * - * apache/spark - */ + c1: TypedColumn[T, K1], + c2: TypedColumn[T, K2] + ): GroupedBy2Ops[K1, K2, T] = new GroupedBy2Ops[K1, K2, T](this, c1, c2) + + /** + * Groups the [[TypedDataset]] using the specified columns, so that we can run aggregation on them. + * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. + * + * {{{ + * case class MyClass(a: Int, b: Int, c: Int) + * val ds: TypedDataset[MyClass] + * + * val cubeDataset: TypedDataset[(Option[A], Option[B], Long)] = + * ds.groupByMany(ds('a), ds('b)).agg(count[MyClass]()) + * + * // original dataset: + * a b c + * 10 20 1 + * 15 25 2 + * + * // after aggregation: + * _1 _2 _3 + * 10 20 1 + * 15 25 1 + * + * }}} + * + * apache/spark + */ object groupByMany extends ProductArgs { - def applyProduct[TK <: HList, K <: HList, KT](groupedBy: TK) - (implicit + + def applyProduct[TK <: HList, K <: HList, KT]( + groupedBy: TK + )(implicit i0: ColumnTypes.Aux[T, TK, K], i1: Tupler.Aux[K, KT], i2: ToTraversable.Aux[TK, List, UntypedExpression[T]] - ): GroupedByManyOps[T, TK, K, KT] = new GroupedByManyOps[T, TK, K, KT](self, groupedBy) + ): GroupedByManyOps[T, TK, K, KT] = + new GroupedByManyOps[T, TK, K, KT](self, groupedBy) } /** - * Create a multi-dimensional rollup for the current [[TypedDataset]] using the specified column, - * so we can run aggregation on it. - * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. - * - * Differs from `Dataset#rollup` by wrapping values into `Option` instead of returning `null`. - * - * apache/spark - */ + * Create a multi-dimensional rollup for the current [[TypedDataset]] using the specified column, + * so we can run aggregation on it. + * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. + * + * Differs from `Dataset#rollup` by wrapping values into `Option` instead of returning `null`. + * + * apache/spark + */ def rollup[K1]( - c1: TypedColumn[T, K1] - ): Rollup1Ops[K1, T] = new Rollup1Ops[K1, T](this, c1) - - /** - * Create a multi-dimensional rollup for the current [[TypedDataset]] using the specified columns, - * so we can run aggregation on them. - * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. - * - * Differs from `Dataset#rollup` by wrapping values into `Option` instead of returning `null`. - * - * apache/spark - */ + c1: TypedColumn[T, K1] + ): Rollup1Ops[K1, T] = new Rollup1Ops[K1, T](this, c1) + + /** + * Create a multi-dimensional rollup for the current [[TypedDataset]] using the specified columns, + * so we can run aggregation on them. + * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. + * + * Differs from `Dataset#rollup` by wrapping values into `Option` instead of returning `null`. + * + * apache/spark + */ def rollup[K1, K2]( - c1: TypedColumn[T, K1], - c2: TypedColumn[T, K2] - ): Rollup2Ops[K1, K2, T] = new Rollup2Ops[K1, K2, T](this, c1, c2) - - /** - * Create a multi-dimensional rollup for the current [[TypedDataset]] using the specified columns, - * so we can run aggregation on them. 
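`cube` and `rollup` behave like their Spark counterparts but, as the Scaladoc notes, wrap the grouping columns in `Option` instead of producing nulls; a sketch with a hypothetical `Sale` type:

```scala
import frameless.TypedDataset
import frameless.functions.aggregate._

case class Sale(region: String, product: String, qty: Long)
val sales: TypedDataset[Sale] = ???

// Missing grouping levels show up as None rather than null.
val cubed: TypedDataset[(Option[String], Option[String], Long)] =
  sales.cube(sales('region), sales('product)).agg(count[Sale]())

val rolled: TypedDataset[(Option[String], Option[String], Long)] =
  sales.rollup(sales('region), sales('product)).agg(count[Sale]())
```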
- * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. - * - * {{{ - * case class MyClass(a: Int, b: Int, c: Int) - * val ds: TypedDataset[MyClass] - * - * val cubeDataset: TypedDataset[(Option[A], Option[B], Long)] = - * ds.rollupMany(ds('a), ds('b)).agg(count[MyClass]()) - * - * // original dataset: - * a b c - * 10 20 1 - * 15 25 2 - * - * // after aggregation: - * _1 _2 _3 - * 15 null 1 - * 15 25 1 - * null null 2 - * 10 null 1 - * 10 20 1 - * - * }}} - * - * Differs from `Dataset#rollup` by wrapping values into `Option` instead of returning `null`. - * - * apache/spark - */ + c1: TypedColumn[T, K1], + c2: TypedColumn[T, K2] + ): Rollup2Ops[K1, K2, T] = new Rollup2Ops[K1, K2, T](this, c1, c2) + + /** + * Create a multi-dimensional rollup for the current [[TypedDataset]] using the specified columns, + * so we can run aggregation on them. + * See [[frameless.functions.AggregateFunctions]] for all the available aggregate functions. + * + * {{{ + * case class MyClass(a: Int, b: Int, c: Int) + * val ds: TypedDataset[MyClass] + * + * val cubeDataset: TypedDataset[(Option[A], Option[B], Long)] = + * ds.rollupMany(ds('a), ds('b)).agg(count[MyClass]()) + * + * // original dataset: + * a b c + * 10 20 1 + * 15 25 2 + * + * // after aggregation: + * _1 _2 _3 + * 15 null 1 + * 15 25 1 + * null null 2 + * 10 null 1 + * 10 20 1 + * + * }}} + * + * Differs from `Dataset#rollup` by wrapping values into `Option` instead of returning `null`. + * + * apache/spark + */ object rollupMany extends ProductArgs { - def applyProduct[TK <: HList, K <: HList, KT](groupedBy: TK) - (implicit + + def applyProduct[TK <: HList, K <: HList, KT]( + groupedBy: TK + )(implicit i0: ColumnTypes.Aux[T, TK, K], i1: Tupler.Aux[K, KT], i2: ToTraversable.Aux[TK, List, UntypedExpression[T]] - ): RollupManyOps[T, TK, K, KT] = new RollupManyOps[T, TK, K, KT](self, groupedBy) + ): RollupManyOps[T, TK, K, KT] = + new RollupManyOps[T, TK, K, KT](self, groupedBy) } /** Computes the cartesian project of `this` `Dataset` with the `other` `Dataset` */ - def joinCross[U](other: TypedDataset[U]) - (implicit e: TypedEncoder[(T, U)]): TypedDataset[(T, U)] = - new TypedDataset(self.dataset.joinWith(other.dataset, new Column(Literal(true)), "cross")) - - /** Computes the full outer join of `this` `Dataset` with the `other` `Dataset`, - * returning a `Tuple2` for each pair where condition evaluates to true. - */ - def joinFull[U](other: TypedDataset[U])(condition: TypedColumn[T with U, Boolean]) - (implicit e: TypedEncoder[(Option[T], Option[U])]): TypedDataset[(Option[T], Option[U])] = - new TypedDataset(self.dataset.joinWith(other.dataset, condition.untyped, "full") - .as[(Option[T], Option[U])](TypedExpressionEncoder[(Option[T], Option[U])])) - - /** Computes the inner join of `this` `Dataset` with the `other` `Dataset`, - * returning a `Tuple2` for each pair where condition evaluates to true. 
- */ - def joinInner[U](other: TypedDataset[U])(condition: TypedColumn[T with U, Boolean]) - (implicit e: TypedEncoder[(T, U)]): TypedDataset[(T, U)] = { - import FramelessInternals._ - - val leftPlan = logicalPlan(dataset) - val rightPlan = logicalPlan(other.dataset) - val join = disambiguate(Join(leftPlan, rightPlan, Inner, Some(condition.expr), JoinHint.NONE)) - val joinedPlan = joinPlan(dataset, join, leftPlan, rightPlan) - val joinedDs = mkDataset(dataset.sqlContext, joinedPlan, TypedExpressionEncoder[(T, U)]) - - TypedDataset.create[(T, U)](joinedDs) - } + def joinCross[U]( + other: TypedDataset[U] + )(implicit + e: TypedEncoder[(T, U)] + ): TypedDataset[(T, U)] = + new TypedDataset( + self.dataset.joinWith(other.dataset, new Column(Literal(true)), "cross") + ) - /** Computes the left outer join of `this` `Dataset` with the `other` `Dataset`, - * returning a `Tuple2` for each pair where condition evaluates to true. - */ - def joinLeft[U](other: TypedDataset[U])(condition: TypedColumn[T with U, Boolean]) - (implicit e: TypedEncoder[(T, Option[U])]): TypedDataset[(T, Option[U])] = - new TypedDataset(self.dataset.joinWith(other.dataset, condition.untyped, "left_outer") - .as[(T, Option[U])](TypedExpressionEncoder[(T, Option[U])])) - - /** Computes the left semi join of `this` `Dataset` with the `other` `Dataset`, - * returning a `Tuple2` for each pair where condition evaluates to true. - */ - def joinLeftSemi[U](other: TypedDataset[U])(condition: TypedColumn[T with U, Boolean]): TypedDataset[T] = - new TypedDataset(self.dataset.join(other.dataset, condition.untyped, "leftsemi") - .as[T](TypedExpressionEncoder(encoder))) - - /** Computes the left anti join of `this` `Dataset` with the `other` `Dataset`, - * returning a `Tuple2` for each pair where condition evaluates to true. - */ - def joinLeftAnti[U](other: TypedDataset[U])(condition: TypedColumn[T with U, Boolean]): TypedDataset[T] = - new TypedDataset(self.dataset.join(other.dataset, condition.untyped, "leftanti") - .as[T](TypedExpressionEncoder(encoder))) - - /** Computes the right outer join of `this` `Dataset` with the `other` `Dataset`, - * returning a `Tuple2` for each pair where condition evaluates to true. - */ - def joinRight[U](other: TypedDataset[U])(condition: TypedColumn[T with U, Boolean]) - (implicit e: TypedEncoder[(Option[T], U)]): TypedDataset[(Option[T], U)] = - new TypedDataset(self.dataset.joinWith(other.dataset, condition.untyped, "right_outer") - .as[(Option[T], U)](TypedExpressionEncoder[(Option[T], U)])) + /** + * Computes the full outer join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + */ + def joinFull[U]( + other: TypedDataset[U] + )(condition: TypedColumn[T with U, Boolean] + )(implicit + e: TypedEncoder[(Option[T], Option[U])] + ): TypedDataset[(Option[T], Option[U])] = + new TypedDataset( + self.dataset + .joinWith(other.dataset, condition.untyped, "full") + .as[(Option[T], Option[U])]( + TypedExpressionEncoder[(Option[T], Option[U])] + ) + ) + + /** + * Computes the inner join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. 
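The join family reformatted here shares one shape: the other dataset, then a condition of type `TypedColumn[T with U, Boolean]`. A sketch with hypothetical `Person`/`Purchase` types, assuming the `===` column comparison defined elsewhere in this file:

```scala
import frameless.TypedDataset

case class Person(id: Long, name: String)
case class Purchase(personId: Long, amount: Long)

val people:    TypedDataset[Person]   = ???
val purchases: TypedDataset[Purchase] = ???

val cond = people('id) === purchases('personId)

val inner: TypedDataset[(Person, Purchase)]         = people.joinInner(purchases)(cond)
val left:  TypedDataset[(Person, Option[Purchase])] = people.joinLeft(purchases)(cond)

// Semi and anti joins keep only left-hand records (with or without a match).
val buyers:    TypedDataset[Person] = people.joinLeftSemi(purchases)(cond)
val nonBuyers: TypedDataset[Person] = people.joinLeftAnti(purchases)(cond)
```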
+ */ + def joinInner[U]( + other: TypedDataset[U] + )(condition: TypedColumn[T with U, Boolean] + )(implicit + e: TypedEncoder[(T, U)] + ): TypedDataset[(T, U)] = { + import FramelessInternals._ + + val leftPlan = logicalPlan(dataset) + val rightPlan = logicalPlan(other.dataset) + val join = disambiguate( + Join(leftPlan, rightPlan, Inner, Some(condition.expr), JoinHint.NONE) + ) + val joinedPlan = joinPlan(dataset, join, leftPlan, rightPlan) + val joinedDs = + mkDataset(dataset.sqlContext, joinedPlan, TypedExpressionEncoder[(T, U)]) + + TypedDataset.create[(T, U)](joinedDs) + } + + /** + * Computes the left outer join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + */ + def joinLeft[U]( + other: TypedDataset[U] + )(condition: TypedColumn[T with U, Boolean] + )(implicit + e: TypedEncoder[(T, Option[U])] + ): TypedDataset[(T, Option[U])] = + new TypedDataset( + self.dataset + .joinWith(other.dataset, condition.untyped, "left_outer") + .as[(T, Option[U])](TypedExpressionEncoder[(T, Option[U])]) + ) + + /** + * Computes the left semi join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + */ + def joinLeftSemi[U]( + other: TypedDataset[U] + )(condition: TypedColumn[T with U, Boolean] + ): TypedDataset[T] = + new TypedDataset( + self.dataset + .join(other.dataset, condition.untyped, "leftsemi") + .as[T](TypedExpressionEncoder(encoder)) + ) + + /** + * Computes the left anti join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + */ + def joinLeftAnti[U]( + other: TypedDataset[U] + )(condition: TypedColumn[T with U, Boolean] + ): TypedDataset[T] = + new TypedDataset( + self.dataset + .join(other.dataset, condition.untyped, "leftanti") + .as[T](TypedExpressionEncoder(encoder)) + ) + + /** + * Computes the right outer join of `this` `Dataset` with the `other` `Dataset`, + * returning a `Tuple2` for each pair where condition evaluates to true. + */ + def joinRight[U]( + other: TypedDataset[U] + )(condition: TypedColumn[T with U, Boolean] + )(implicit + e: TypedEncoder[(Option[T], U)] + ): TypedDataset[(Option[T], U)] = + new TypedDataset( + self.dataset + .joinWith(other.dataset, condition.untyped, "right_outer") + .as[(Option[T], U)](TypedExpressionEncoder[(Option[T], U)]) + ) private def disambiguate(join: Join): Join = { - val plan = FramelessInternals.ofRows(dataset.sparkSession, join).queryExecution.analyzed.asInstanceOf[Join] + val plan = FramelessInternals + .ofRows(dataset.sparkSession, join) + .queryExecution + .analyzed + .asInstanceOf[Join] val disambiguated = plan.condition.map(_.transform { case FramelessInternals.DisambiguateLeft(tagged: AttributeReference) => val leftDs = FramelessInternals.ofRows(spark, plan.left) @@ -707,43 +897,82 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val plan.copy(condition = disambiguated) } - /** Takes a function from A => R and converts it to a UDF for TypedColumn[T, A] => TypedColumn[T, R]. - */ - def makeUDF[A: TypedEncoder, R: TypedEncoder](f: A => R): - TypedColumn[T, A] => TypedColumn[T, R] = functions.udf(f) - - /** Takes a function from (A1, A2) => R and converts it to a UDF for - * (TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R]. 
- */ - def makeUDF[A1: TypedEncoder, A2: TypedEncoder, R: TypedEncoder](f: (A1, A2) => R): - (TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R] = functions.udf(f) - - /** Takes a function from (A1, A2, A3) => R and converts it to a UDF for - * (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R]. - */ - def makeUDF[A1: TypedEncoder, A2: TypedEncoder, A3: TypedEncoder, R: TypedEncoder](f: (A1, A2, A3) => R): - (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R] = functions.udf(f) - - /** Takes a function from (A1, A2, A3, A4) => R and converts it to a UDF for - * (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R]. - */ - def makeUDF[A1: TypedEncoder, A2: TypedEncoder, A3: TypedEncoder, A4: TypedEncoder, R: TypedEncoder](f: (A1, A2, A3, A4) => R): - (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R] = functions.udf(f) - - /** Takes a function from (A1, A2, A3, A4, A5) => R and converts it to a UDF for - * (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R]. - */ - def makeUDF[A1: TypedEncoder, A2: TypedEncoder, A3: TypedEncoder, A4: TypedEncoder, A5: TypedEncoder, R: TypedEncoder](f: (A1, A2, A3, A4, A5) => R): - (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R] = functions.udf(f) - - /** Type-safe projection from type T to Tuple1[A] - * {{{ - * d.select( d('a), d('a)+d('b), ... ) - * }}} - */ + /** + * Takes a function from A => R and converts it to a UDF for TypedColumn[T, A] => TypedColumn[T, R]. + */ + def makeUDF[A: TypedEncoder, R: TypedEncoder]( + f: A => R + ): TypedColumn[T, A] => TypedColumn[T, R] = functions.udf(f) + + /** + * Takes a function from (A1, A2) => R and converts it to a UDF for + * (TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R]. + */ + def makeUDF[A1: TypedEncoder, A2: TypedEncoder, R: TypedEncoder]( + f: (A1, A2) => R + ): (TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R] = + functions.udf(f) + + /** + * Takes a function from (A1, A2, A3) => R and converts it to a UDF for + * (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R]. + */ + def makeUDF[ + A1: TypedEncoder, + A2: TypedEncoder, + A3: TypedEncoder, + R: TypedEncoder + ](f: (A1, A2, A3) => R + ): (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R] = + functions.udf(f) + + /** + * Takes a function from (A1, A2, A3, A4) => R and converts it to a UDF for + * (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R]. + */ + def makeUDF[ + A1: TypedEncoder, + A2: TypedEncoder, + A3: TypedEncoder, + A4: TypedEncoder, + R: TypedEncoder + ](f: (A1, A2, A3, A4) => R + ): ( + TypedColumn[T, A1], + TypedColumn[T, A2], + TypedColumn[T, A3], + TypedColumn[T, A4] + ) => TypedColumn[T, R] = functions.udf(f) + + /** + * Takes a function from (A1, A2, A3, A4, A5) => R and converts it to a UDF for + * (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R]. 
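The `makeUDF` overloads lift plain Scala functions of one to five arguments into column-level functions; a sketch for the unary case, with a hypothetical `Person` type:

```scala
import frameless.TypedDataset

case class Person(name: String, age: Int)
val people: TypedDataset[Person] = ???

// A => R becomes TypedColumn[T, A] => TypedColumn[T, R].
val yearsTo65 = people.makeUDF((age: Int) => math.max(0, 65 - age))

val remaining: TypedDataset[(String, Int)] =
  people.select(people('name), yearsTo65(people('age)))
```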
+ */ + def makeUDF[ + A1: TypedEncoder, + A2: TypedEncoder, + A3: TypedEncoder, + A4: TypedEncoder, + A5: TypedEncoder, + R: TypedEncoder + ](f: (A1, A2, A3, A4, A5) => R + ): ( + TypedColumn[T, A1], + TypedColumn[T, A2], + TypedColumn[T, A3], + TypedColumn[T, A4], + TypedColumn[T, A5] + ) => TypedColumn[T, R] = functions.udf(f) + + /** + * Type-safe projection from type T to Tuple1[A] + * {{{ + * d.select( d('a), d('a)+d('b), ... ) + * }}} + */ def select[A]( - ca: TypedColumn[T, A] - ): TypedDataset[A] = { + ca: TypedColumn[T, A] + ): TypedDataset[A] = { implicit val ea = ca.uencoder val tuple1: TypedDataset[Tuple1[A]] = selectMany(ca) @@ -753,10 +982,8 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val TypedEncoder[A].catalystRepr match { case StructType(_) => // if column is struct, we use all its fields - val df = tuple1 - .dataset - .selectExpr("_1.*") - .as[A](TypedExpressionEncoder[A]) + val df = + tuple1.dataset.selectExpr("_1.*").as[A](TypedExpressionEncoder[A]) TypedDataset.create(df) case other => @@ -765,217 +992,288 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val } } - /** Type-safe projection from type T to Tuple2[A,B] - * {{{ - * d.select( d('a), d('a)+d('b), ... ) - * }}} - */ + /** + * Type-safe projection from type T to Tuple2[A,B] + * {{{ + * d.select( d('a), d('a)+d('b), ... ) + * }}} + */ def select[A, B]( - ca: TypedColumn[T, A], - cb: TypedColumn[T, B] - ): TypedDataset[(A, B)] = { + ca: TypedColumn[T, A], + cb: TypedColumn[T, B] + ): TypedDataset[(A, B)] = { implicit val (ea, eb) = (ca.uencoder, cb.uencoder) selectMany(ca, cb) } - /** Type-safe projection from type T to Tuple3[A,B,...] - * {{{ - * d.select( d('a), d('a)+d('b), ... ) - * }}} - */ + /** + * Type-safe projection from type T to Tuple3[A,B,...] + * {{{ + * d.select( d('a), d('a)+d('b), ... ) + * }}} + */ def select[A, B, C]( - ca: TypedColumn[T, A], - cb: TypedColumn[T, B], - cc: TypedColumn[T, C] - ): TypedDataset[(A, B, C)] = { + ca: TypedColumn[T, A], + cb: TypedColumn[T, B], + cc: TypedColumn[T, C] + ): TypedDataset[(A, B, C)] = { implicit val (ea, eb, ec) = (ca.uencoder, cb.uencoder, cc.uencoder) selectMany(ca, cb, cc) } - /** Type-safe projection from type T to Tuple4[A,B,...] - * {{{ - * d.select( d('a), d('a)+d('b), ... ) - * }}} - */ + /** + * Type-safe projection from type T to Tuple4[A,B,...] + * {{{ + * d.select( d('a), d('a)+d('b), ... ) + * }}} + */ def select[A, B, C, D]( - ca: TypedColumn[T, A], - cb: TypedColumn[T, B], - cc: TypedColumn[T, C], - cd: TypedColumn[T, D] - ): TypedDataset[(A, B, C, D)] = { - implicit val (ea, eb, ec, ed) = (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder) + ca: TypedColumn[T, A], + cb: TypedColumn[T, B], + cc: TypedColumn[T, C], + cd: TypedColumn[T, D] + ): TypedDataset[(A, B, C, D)] = { + implicit val (ea, eb, ec, ed) = + (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder) selectMany(ca, cb, cc, cd) } - /** Type-safe projection from type T to Tuple5[A,B,...] - * {{{ - * d.select( d('a), d('a)+d('b), ... ) - * }}} - */ + /** + * Type-safe projection from type T to Tuple5[A,B,...] + * {{{ + * d.select( d('a), d('a)+d('b), ... 
) + * }}} + */ def select[A, B, C, D, E]( - ca: TypedColumn[T, A], - cb: TypedColumn[T, B], - cc: TypedColumn[T, C], - cd: TypedColumn[T, D], - ce: TypedColumn[T, E] - ): TypedDataset[(A, B, C, D, E)] = { + ca: TypedColumn[T, A], + cb: TypedColumn[T, B], + cc: TypedColumn[T, C], + cd: TypedColumn[T, D], + ce: TypedColumn[T, E] + ): TypedDataset[(A, B, C, D, E)] = { implicit val (ea, eb, ec, ed, ee) = (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder, ce.uencoder) selectMany(ca, cb, cc, cd, ce) } - /** Type-safe projection from type T to Tuple6[A,B,...] - * {{{ - * d.select( d('a), d('a)+d('b), ... ) - * }}} - */ + /** + * Type-safe projection from type T to Tuple6[A,B,...] + * {{{ + * d.select( d('a), d('a)+d('b), ... ) + * }}} + */ def select[A, B, C, D, E, F]( - ca: TypedColumn[T, A], - cb: TypedColumn[T, B], - cc: TypedColumn[T, C], - cd: TypedColumn[T, D], - ce: TypedColumn[T, E], - cf: TypedColumn[T, F] - ): TypedDataset[(A, B, C, D, E, F)] = { + ca: TypedColumn[T, A], + cb: TypedColumn[T, B], + cc: TypedColumn[T, C], + cd: TypedColumn[T, D], + ce: TypedColumn[T, E], + cf: TypedColumn[T, F] + ): TypedDataset[(A, B, C, D, E, F)] = { implicit val (ea, eb, ec, ed, ee, ef) = - (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder, ce.uencoder, cf.uencoder) + ( + ca.uencoder, + cb.uencoder, + cc.uencoder, + cd.uencoder, + ce.uencoder, + cf.uencoder + ) selectMany(ca, cb, cc, cd, ce, cf) } - /** Type-safe projection from type T to Tuple7[A,B,...] - * {{{ - * d.select( d('a), d('a)+d('b), ... ) - * }}} - */ + /** + * Type-safe projection from type T to Tuple7[A,B,...] + * {{{ + * d.select( d('a), d('a)+d('b), ... ) + * }}} + */ def select[A, B, C, D, E, F, G]( - ca: TypedColumn[T, A], - cb: TypedColumn[T, B], - cc: TypedColumn[T, C], - cd: TypedColumn[T, D], - ce: TypedColumn[T, E], - cf: TypedColumn[T, F], - cg: TypedColumn[T, G] - ): TypedDataset[(A, B, C, D, E, F, G)] = { + ca: TypedColumn[T, A], + cb: TypedColumn[T, B], + cc: TypedColumn[T, C], + cd: TypedColumn[T, D], + ce: TypedColumn[T, E], + cf: TypedColumn[T, F], + cg: TypedColumn[T, G] + ): TypedDataset[(A, B, C, D, E, F, G)] = { implicit val (ea, eb, ec, ed, ee, ef, eg) = - (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder, ce.uencoder, cf.uencoder, cg.uencoder) + ( + ca.uencoder, + cb.uencoder, + cc.uencoder, + cd.uencoder, + ce.uencoder, + cf.uencoder, + cg.uencoder + ) selectMany(ca, cb, cc, cd, ce, cf, cg) } - /** Type-safe projection from type T to Tuple8[A,B,...] - * {{{ - * d.select( d('a), d('a)+d('b), ... ) - * }}} - */ + /** + * Type-safe projection from type T to Tuple8[A,B,...] + * {{{ + * d.select( d('a), d('a)+d('b), ... 
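The `select` overloads project to tuples of increasing arity; the Scaladoc's `d('a)+d('b)` shows that computed column expressions are accepted as well, as in this sketch with a hypothetical `LineItem` type:

```scala
import frameless.TypedDataset

case class LineItem(sku: String, net: Double, tax: Double)
val items: TypedDataset[LineItem] = ???

// Two plain columns and one computed column still yield a fully typed tuple.
val gross: TypedDataset[(String, Double, Double)] =
  items.select(items('sku), items('net), items('net) + items('tax))
```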
) + * }}} + */ def select[A, B, C, D, E, F, G, H]( - ca: TypedColumn[T, A], - cb: TypedColumn[T, B], - cc: TypedColumn[T, C], - cd: TypedColumn[T, D], - ce: TypedColumn[T, E], - cf: TypedColumn[T, F], - cg: TypedColumn[T, G], - ch: TypedColumn[T, H] - ): TypedDataset[(A, B, C, D, E, F, G, H)] = { + ca: TypedColumn[T, A], + cb: TypedColumn[T, B], + cc: TypedColumn[T, C], + cd: TypedColumn[T, D], + ce: TypedColumn[T, E], + cf: TypedColumn[T, F], + cg: TypedColumn[T, G], + ch: TypedColumn[T, H] + ): TypedDataset[(A, B, C, D, E, F, G, H)] = { implicit val (ea, eb, ec, ed, ee, ef, eg, eh) = - (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder, ce.uencoder, cf.uencoder, cg.uencoder, ch.uencoder) + ( + ca.uencoder, + cb.uencoder, + cc.uencoder, + cd.uencoder, + ce.uencoder, + cf.uencoder, + cg.uencoder, + ch.uencoder + ) selectMany(ca, cb, cc, cd, ce, cf, cg, ch) } - /** Type-safe projection from type T to Tuple9[A,B,...] - * {{{ - * d.select( d('a), d('a)+d('b), ... ) - * }}} - */ + /** + * Type-safe projection from type T to Tuple9[A,B,...] + * {{{ + * d.select( d('a), d('a)+d('b), ... ) + * }}} + */ def select[A, B, C, D, E, F, G, H, I]( - ca: TypedColumn[T, A], - cb: TypedColumn[T, B], - cc: TypedColumn[T, C], - cd: TypedColumn[T, D], - ce: TypedColumn[T, E], - cf: TypedColumn[T, F], - cg: TypedColumn[T, G], - ch: TypedColumn[T, H], - ci: TypedColumn[T, I] - ): TypedDataset[(A, B, C, D, E, F, G, H, I)] = { + ca: TypedColumn[T, A], + cb: TypedColumn[T, B], + cc: TypedColumn[T, C], + cd: TypedColumn[T, D], + ce: TypedColumn[T, E], + cf: TypedColumn[T, F], + cg: TypedColumn[T, G], + ch: TypedColumn[T, H], + ci: TypedColumn[T, I] + ): TypedDataset[(A, B, C, D, E, F, G, H, I)] = { implicit val (ea, eb, ec, ed, ee, ef, eg, eh, ei) = - (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder, ce.uencoder, cf.uencoder, cg.uencoder, ch.uencoder, ci.uencoder) + ( + ca.uencoder, + cb.uencoder, + cc.uencoder, + cd.uencoder, + ce.uencoder, + cf.uencoder, + cg.uencoder, + ch.uencoder, + ci.uencoder + ) selectMany(ca, cb, cc, cd, ce, cf, cg, ch, ci) } - /** Type-safe projection from type T to Tuple10[A,B,...] - * {{{ - * d.select( d('a), d('a)+d('b), ... ) - * }}} - */ + /** + * Type-safe projection from type T to Tuple10[A,B,...] + * {{{ + * d.select( d('a), d('a)+d('b), ... 
) + * }}} + */ def select[A, B, C, D, E, F, G, H, I, J]( - ca: TypedColumn[T, A], - cb: TypedColumn[T, B], - cc: TypedColumn[T, C], - cd: TypedColumn[T, D], - ce: TypedColumn[T, E], - cf: TypedColumn[T, F], - cg: TypedColumn[T, G], - ch: TypedColumn[T, H], - ci: TypedColumn[T, I], - cj: TypedColumn[T, J] - ): TypedDataset[(A, B, C, D, E, F, G, H, I, J)] = { + ca: TypedColumn[T, A], + cb: TypedColumn[T, B], + cc: TypedColumn[T, C], + cd: TypedColumn[T, D], + ce: TypedColumn[T, E], + cf: TypedColumn[T, F], + cg: TypedColumn[T, G], + ch: TypedColumn[T, H], + ci: TypedColumn[T, I], + cj: TypedColumn[T, J] + ): TypedDataset[(A, B, C, D, E, F, G, H, I, J)] = { implicit val (ea, eb, ec, ed, ee, ef, eg, eh, ei, ej) = - (ca.uencoder, cb.uencoder, cc.uencoder, cd.uencoder, ce.uencoder, cf.uencoder, cg.uencoder, ch.uencoder, ci.uencoder, cj.uencoder) + ( + ca.uencoder, + cb.uencoder, + cc.uencoder, + cd.uencoder, + ce.uencoder, + cf.uencoder, + cg.uencoder, + ch.uencoder, + ci.uencoder, + cj.uencoder + ) selectMany(ca, cb, cc, cd, ce, cf, cg, ch, ci, cj) } object selectMany extends ProductArgs { - def applyProduct[U <: HList, Out0 <: HList, Out](columns: U) - (implicit + + def applyProduct[U <: HList, Out0 <: HList, Out]( + columns: U + )(implicit i0: ColumnTypes.Aux[T, U, Out0], i1: ToTraversable.Aux[U, List, UntypedExpression[T]], i2: Tupler.Aux[Out0, Out], i3: TypedEncoder[Out] ): TypedDataset[Out] = { - val base = dataset.toDF() - .select(columns.toList[UntypedExpression[T]].map(c => new Column(c.expr)):_*) - val selected = base.as[Out](TypedExpressionEncoder[Out]) + val base = dataset + .toDF() + .select( + columns.toList[UntypedExpression[T]].map(c => new Column(c.expr)): _* + ) + val selected = base.as[Out](TypedExpressionEncoder[Out]) - TypedDataset.create[Out](selected) - } + TypedDataset.create[Out](selected) + } } /** Sort each partition in the dataset using the columns selected. */ - def sortWithinPartitions[A: CatalystOrdered](ca: SortedTypedColumn[T, A]): TypedDataset[T] = + def sortWithinPartitions[A: CatalystOrdered]( + ca: SortedTypedColumn[T, A] + ): TypedDataset[T] = sortWithinPartitionsMany(ca) /** Sort each partition in the dataset using the columns selected. */ def sortWithinPartitions[A: CatalystOrdered, B: CatalystOrdered]( - ca: SortedTypedColumn[T, A], - cb: SortedTypedColumn[T, B] - ): TypedDataset[T] = sortWithinPartitionsMany(ca, cb) + ca: SortedTypedColumn[T, A], + cb: SortedTypedColumn[T, B] + ): TypedDataset[T] = sortWithinPartitionsMany(ca, cb) /** Sort each partition in the dataset using the columns selected. */ - def sortWithinPartitions[A: CatalystOrdered, B: CatalystOrdered, C: CatalystOrdered]( - ca: SortedTypedColumn[T, A], - cb: SortedTypedColumn[T, B], - cc: SortedTypedColumn[T, C] - ): TypedDataset[T] = sortWithinPartitionsMany(ca, cb, cc) - - /** Sort each partition in the dataset by the given column expressions - * Default sort order is ascending. - * {{{ - * d.sortWithinPartitionsMany(d('a), d('b).desc, d('c).asc) - * }}} - */ + def sortWithinPartitions[ + A: CatalystOrdered, + B: CatalystOrdered, + C: CatalystOrdered + ](ca: SortedTypedColumn[T, A], + cb: SortedTypedColumn[T, B], + cc: SortedTypedColumn[T, C] + ): TypedDataset[T] = sortWithinPartitionsMany(ca, cb, cc) + + /** + * Sort each partition in the dataset by the given column expressions + * Default sort order is ascending. 
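For orientation, a minimal sketch of the arity-based `select` overloads above (illustrative case class; assumes an implicit `SparkSession` and the derived `TypedEncoder`s in scope):

{{{
case class Person(name: String, age: Int, city: String)

val people: TypedDataset[Person] =
  TypedDataset.create(Seq(Person("Ana", 30, "Lisbon"), Person("Bo", 25, "Oslo")))

// Each multi-column overload projects to the matching TupleN, preserving the column types
val pair: TypedDataset[(String, Int)] =
  people.select(people('name), people('age))

val triple: TypedDataset[(String, Int, String)] =
  people.select(people('name), people('age), people('city))
}}}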
+ * {{{ + * d.sortWithinPartitionsMany(d('a), d('b).desc, d('c).asc) + * }}} + */ object sortWithinPartitionsMany extends ProductArgs { - def applyProduct[U <: HList, O <: HList](columns: U) - (implicit + + def applyProduct[U <: HList, O <: HList]( + columns: U + )(implicit i0: Mapper.Aux[SortedTypedColumn.defaultAscendingPoly.type, U, O], i1: ToTraversable.Aux[O, List, SortedTypedColumn[T, _]] ): TypedDataset[T] = { - val sorted = dataset.toDF() - .sortWithinPartitions(i0(columns).toList[SortedTypedColumn[T, _]].map(_.untyped):_*) + val sorted = dataset + .toDF() + .sortWithinPartitions( + i0(columns).toList[SortedTypedColumn[T, _]].map(_.untyped): _* + ) .as[T](TypedExpressionEncoder[T]) TypedDataset.create[T](sorted) @@ -983,273 +1281,316 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val } /** Orders the TypedDataset using the column selected. */ - def orderBy[A: CatalystOrdered](ca: SortedTypedColumn[T, A]): TypedDataset[T] = + def orderBy[A: CatalystOrdered]( + ca: SortedTypedColumn[T, A] + ): TypedDataset[T] = orderByMany(ca) /** Orders the TypedDataset using the columns selected. */ def orderBy[A: CatalystOrdered, B: CatalystOrdered]( - ca: SortedTypedColumn[T, A], - cb: SortedTypedColumn[T, B] - ): TypedDataset[T] = orderByMany(ca, cb) - - /** Orders the TypedDataset using the columns selected. */ - def orderBy[A: CatalystOrdered, B: CatalystOrdered, C: CatalystOrdered]( - ca: SortedTypedColumn[T, A], - cb: SortedTypedColumn[T, B], - cc: SortedTypedColumn[T, C] - ): TypedDataset[T] = orderByMany(ca, cb, cc) - - /** Sort the dataset by any number of column expressions. - * Default sort order is ascending. - * {{{ - * d.orderByMany(d('a), d('b).desc, d('c).asc) - * }}} - */ + ca: SortedTypedColumn[T, A], + cb: SortedTypedColumn[T, B] + ): TypedDataset[T] = orderByMany(ca, cb) + + /** Orders the TypedDataset using the columns selected. */ + def orderBy[A: CatalystOrdered, B: CatalystOrdered, C: CatalystOrdered]( + ca: SortedTypedColumn[T, A], + cb: SortedTypedColumn[T, B], + cc: SortedTypedColumn[T, C] + ): TypedDataset[T] = orderByMany(ca, cb, cc) + + /** + * Sort the dataset by any number of column expressions. + * Default sort order is ascending. + * {{{ + * d.orderByMany(d('a), d('b).desc, d('c).asc) + * }}} + */ object orderByMany extends ProductArgs { - def applyProduct[U <: HList, O <: HList](columns: U) - (implicit + + def applyProduct[U <: HList, O <: HList]( + columns: U + )(implicit i0: Mapper.Aux[SortedTypedColumn.defaultAscendingPoly.type, U, O], i1: ToTraversable.Aux[O, List, SortedTypedColumn[T, _]] ): TypedDataset[T] = { - val sorted = dataset.toDF() - .orderBy(i0(columns).toList[SortedTypedColumn[T, _]].map(_.untyped):_*) + val sorted = dataset + .toDF() + .orderBy(i0(columns).toList[SortedTypedColumn[T, _]].map(_.untyped): _*) .as[T](TypedExpressionEncoder[T]) TypedDataset.create[T](sorted) } } - /** Returns a new Dataset as a tuple with the specified - * column dropped. - * Does not allow for dropping from a single column TypedDataset - * - * {{{ - * val d: TypedDataset[Foo(a: String, b: Int...)] = ??? 
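The sorting variants above, sketched after the `orderByMany` example in the Scaladoc (illustrative `Person` dataset; bare columns default to ascending):

{{{
case class Person(name: String, age: Int, city: String)
val people: TypedDataset[Person] = ???

// Total ordering of the dataset
val oldestFirst: TypedDataset[Person] = people.orderBy(people('age).desc)

// Several sort expressions at once; a bare column sorts ascending
val sorted: TypedDataset[Person] =
  people.orderByMany(people('city), people('age).desc)

// Same idea, but applied within each partition only (no shuffle)
val perPartition: TypedDataset[Person] =
  people.sortWithinPartitions(people('age).asc)
}}}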
- * val result = TypedDataset[(Int, ...)] = d.drop('a) - * }}} - * @param column column to drop specified as a Symbol - * @param i0 LabelledGeneric derived for T - * @param i1 Remover derived for TRep and column - * @param i2 values of T with column removed - * @param i3 tupler of values - * @param i4 evidence of encoder of the tupled values - * @tparam Out Tupled return type - * @tparam TRep shapeless' record representation of T - * @tparam Removed record of T with column removed - * @tparam ValuesFromRemoved values of T with column removed as an HList - * @tparam V value type of column in T - * @return - */ - def dropTupled[Out, TRep <: HList, Removed <: HList, ValuesFromRemoved <: HList, V] - (column: Witness.Lt[Symbol]) - (implicit + /** + * Returns a new Dataset as a tuple with the specified + * column dropped. + * Does not allow for dropping from a single column TypedDataset + * + * {{{ + * val d: TypedDataset[Foo(a: String, b: Int...)] = ??? + * val result = TypedDataset[(Int, ...)] = d.drop('a) + * }}} + * @param column column to drop specified as a Symbol + * @param i0 LabelledGeneric derived for T + * @param i1 Remover derived for TRep and column + * @param i2 values of T with column removed + * @param i3 tupler of values + * @param i4 evidence of encoder of the tupled values + * @tparam Out Tupled return type + * @tparam TRep shapeless' record representation of T + * @tparam Removed record of T with column removed + * @tparam ValuesFromRemoved values of T with column removed as an HList + * @tparam V value type of column in T + * @return + */ + def dropTupled[ + Out, + TRep <: HList, + Removed <: HList, + ValuesFromRemoved <: HList, + V + ](column: Witness.Lt[Symbol] + )(implicit i0: LabelledGeneric.Aux[T, TRep], i1: Remover.Aux[TRep, column.T, (V, Removed)], i2: Values.Aux[Removed, ValuesFromRemoved], i3: Tupler.Aux[ValuesFromRemoved, Out], i4: TypedEncoder[Out] ): TypedDataset[Out] = { - val dropped = dataset - .toDF() - .drop(column.value.name) - .as[Out](TypedExpressionEncoder[Out]) + val dropped = dataset + .toDF() + .drop(column.value.name) + .as[Out](TypedExpressionEncoder[Out]) - TypedDataset.create[Out](dropped) - } + TypedDataset.create[Out](dropped) + } /** - * Drops columns as necessary to return `U` - * - * @example - * {{{ - * case class X(i: Int, j: Int, k: Boolean) - * case class Y(i: Int, k: Boolean) - * val f: TypedDataset[X] = ??? - * val fNew: TypedDataset[Y] = f.drop[Y] - * }}} - * - * @tparam U the output type - * - * @see [[frameless.TypedDataset#project]] - */ - def drop[U](implicit projector: SmartProject[T,U]): TypedDataset[U] = project[U] - - /** Prepends a new column to the Dataset. - * - * {{{ - * case class X(i: Int, j: Int) - * val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil) - * val fNew: TypedDataset[(Int,Int,Boolean)] = f.withColumnTupled(f('j) === 10) - * }}} - */ - def withColumnTupled[A: TypedEncoder, H <: HList, FH <: HList, Out] - (ca: TypedColumn[T, A]) - (implicit + * Drops columns as necessary to return `U` + * + * @example + * {{{ + * case class X(i: Int, j: Int, k: Boolean) + * case class Y(i: Int, k: Boolean) + * val f: TypedDataset[X] = ??? + * val fNew: TypedDataset[Y] = f.drop[Y] + * }}} + * + * @tparam U the output type + * + * @see [[frameless.TypedDataset#project]] + */ + def drop[U]( + implicit + projector: SmartProject[T, U] + ): TypedDataset[U] = project[U] + + /** + * Prepends a new column to the Dataset. 
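The two ways of dropping a column shown above, side by side (mirroring the Scaladoc examples; implicit `SparkSession` in scope):

{{{
case class X(i: Int, j: Int, k: Boolean)
case class Y(i: Int, k: Boolean)

val xs: TypedDataset[X] = TypedDataset.create(Seq(X(1, 2, true), X(3, 4, false)))

// Symbol-based: the remaining columns are re-tupled
val tupled: TypedDataset[(Int, Boolean)] = xs.dropTupled('j)

// Type-directed: whatever columns Y lacks are dropped
val ys: TypedDataset[Y] = xs.drop[Y]
}}}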
+ * + * {{{ + * case class X(i: Int, j: Int) + * val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil) + * val fNew: TypedDataset[(Int,Int,Boolean)] = f.withColumnTupled(f('j) === 10) + * }}} + */ + def withColumnTupled[A: TypedEncoder, H <: HList, FH <: HList, Out]( + ca: TypedColumn[T, A] + )(implicit i0: Generic.Aux[T, H], i1: Prepend.Aux[H, A :: HNil, FH], i2: Tupler.Aux[FH, Out], i3: TypedEncoder[Out] ): TypedDataset[Out] = { - // Giving a random name to the new column (the proper name will be given by the Tuple-based encoder) - val selected = dataset.toDF().withColumn("I1X3T9CU1OP0128JYIO76TYZZA3AXHQ18RMI", ca.untyped) - .as[Out](TypedExpressionEncoder[Out]) + // Giving a random name to the new column (the proper name will be given by the Tuple-based encoder) + val selected = dataset + .toDF() + .withColumn("I1X3T9CU1OP0128JYIO76TYZZA3AXHQ18RMI", ca.untyped) + .as[Out](TypedExpressionEncoder[Out]) - TypedDataset.create[Out](selected) + TypedDataset.create[Out](selected) } - /** Returns a new [[frameless.TypedDataset]] with the specified column updated with a new value - * {{{ - * case class X(i: Int, j: Int) - * val f: TypedDataset[X] = TypedDataset.create(X(1,10) :: Nil) - * val fNew: TypedDataset[X] = f.withColumn('j, f('i)) // results in X(1, 1) :: Nil - * }}} - * @param column column given as a symbol to replace - * @param replacement column to replace the value with - * @param i0 Evidence that a column with the correct type and name exists - */ + /** + * Returns a new [[frameless.TypedDataset]] with the specified column updated with a new value + * {{{ + * case class X(i: Int, j: Int) + * val f: TypedDataset[X] = TypedDataset.create(X(1,10) :: Nil) + * val fNew: TypedDataset[X] = f.withColumn('j, f('i)) // results in X(1, 1) :: Nil + * }}} + * @param column column given as a symbol to replace + * @param replacement column to replace the value with + * @param i0 Evidence that a column with the correct type and name exists + */ def withColumnReplaced[A]( - column: Witness.Lt[Symbol], - replacement: TypedColumn[T, A] - )(implicit - i0: TypedColumn.Exists[T, column.T, A] - ): TypedDataset[T] = { - val updated = dataset.toDF().withColumn(column.value.name, replacement.untyped) + column: Witness.Lt[Symbol], + replacement: TypedColumn[T, A] + )(implicit + i0: TypedColumn.Exists[T, column.T, A] + ): TypedDataset[T] = { + val updated = dataset + .toDF() + .withColumn(column.value.name, replacement.untyped) .as[T](TypedExpressionEncoder[T]) TypedDataset.create[T](updated) } - /** Adds a column to a Dataset so long as the specified output type, `U`, has - * an extra column from `T` that has type `A`. 
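The column-adding and column-updating variants above, using the same illustrative types as their Scaladoc examples (implicit `SparkSession` in scope):

{{{
case class X(i: Int, j: Int)

val f: TypedDataset[X] = TypedDataset.create(Seq(X(1, 1), X(1, 10)))

// Appends a column; the result is re-tupled, so the new column is positional
val appended: TypedDataset[(Int, Int, Boolean)] = f.withColumnTupled(f('j) === 10)

// Replaces an existing column, keeping the schema of X
val replaced: TypedDataset[X] = f.withColumnReplaced('j, f('i))
}}}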
- * - * @example - * {{{ - * case class X(i: Int, j: Int) - * case class Y(i: Int, j: Int, k: Boolean) - * val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil) - * val fNew: TypedDataset[Y] = f.withColumn[Y](f('j) === 10) - * }}} - * @param ca The typed column to add - * @param i0 TypeEncoder for output type U - * @param i1 TypeEncoder for added column type A - * @param i2 the LabelledGeneric derived for T - * @param i3 the LabelledGeneric derived for U - * @param i4 proof no fields have been removed - * @param i5 diff from T to U - * @param i6 keys from newFields - * @param i7 the one and only new key - * @param i8 the one and only new field enforcing the type of A exists - * @param i9 the keys of U - * @param iA allows for traversing the keys of U - * @tparam U the output type - * @tparam A The added column type - * @tparam TRep shapeless' record representation of T - * @tparam URep shapeless' record representation of U - * @tparam UKeys the keys of U as an HList - * @tparam NewFields the added fields to T to get U - * @tparam NewKeys the keys of NewFields as an HList - * @tparam NewKey the first, and only, key in NewKey - * - * @see [[frameless.TypedDataset.WithColumnApply#apply]] - */ + /** + * Adds a column to a Dataset so long as the specified output type, `U`, has + * an extra column from `T` that has type `A`. + * + * @example + * {{{ + * case class X(i: Int, j: Int) + * case class Y(i: Int, j: Int, k: Boolean) + * val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil) + * val fNew: TypedDataset[Y] = f.withColumn[Y](f('j) === 10) + * }}} + * @param ca The typed column to add + * @param i0 TypeEncoder for output type U + * @param i1 TypeEncoder for added column type A + * @param i2 the LabelledGeneric derived for T + * @param i3 the LabelledGeneric derived for U + * @param i4 proof no fields have been removed + * @param i5 diff from T to U + * @param i6 keys from newFields + * @param i7 the one and only new key + * @param i8 the one and only new field enforcing the type of A exists + * @param i9 the keys of U + * @param iA allows for traversing the keys of U + * @tparam U the output type + * @tparam A The added column type + * @tparam TRep shapeless' record representation of T + * @tparam URep shapeless' record representation of U + * @tparam UKeys the keys of U as an HList + * @tparam NewFields the added fields to T to get U + * @tparam NewKeys the keys of NewFields as an HList + * @tparam NewKey the first, and only, key in NewKey + * + * @see [[frameless.TypedDataset.WithColumnApply#apply]] + */ def withColumn[U] = new WithColumnApply[U] class WithColumnApply[U] { - def apply[A, TRep <: HList, URep <: HList, UKeys <: HList, NewFields <: HList, NewKeys <: HList, NewKey <: Symbol] - (ca: TypedColumn[T, A]) - (implicit - i0: TypedEncoder[U], - i1: TypedEncoder[A], - i2: LabelledGeneric.Aux[T, TRep], - i3: LabelledGeneric.Aux[U, URep], - i4: Diff.Aux[TRep, URep, HNil], - i5: Diff.Aux[URep, TRep, NewFields], - i6: Keys.Aux[NewFields, NewKeys], - i7: IsHCons.Aux[NewKeys, NewKey, HNil], - i8: IsHCons.Aux[NewFields, FieldType[NewKey, A], HNil], - i9: Keys.Aux[URep, UKeys], - iA: ToTraversable.Aux[UKeys, Seq, Symbol] - ): TypedDataset[U] = { + + def apply[ + A, + TRep <: HList, + URep <: HList, + UKeys <: HList, + NewFields <: HList, + NewKeys <: HList, + NewKey <: Symbol + ](ca: TypedColumn[T, A] + )(implicit + i0: TypedEncoder[U], + i1: TypedEncoder[A], + i2: LabelledGeneric.Aux[T, TRep], + i3: LabelledGeneric.Aux[U, URep], + i4: 
Diff.Aux[TRep, URep, HNil], + i5: Diff.Aux[URep, TRep, NewFields], + i6: Keys.Aux[NewFields, NewKeys], + i7: IsHCons.Aux[NewKeys, NewKey, HNil], + i8: IsHCons.Aux[NewFields, FieldType[NewKey, A], HNil], + i9: Keys.Aux[URep, UKeys], + iA: ToTraversable.Aux[UKeys, Seq, Symbol] + ): TypedDataset[U] = { val newColumnName = i7.head(i6()).name - val dfWithNewColumn = dataset - .toDF() - .withColumn(newColumnName, ca.untyped) + val dfWithNewColumn = dataset.toDF().withColumn(newColumnName, ca.untyped) val newColumns = i9.apply().to[Seq].map(_.name).map(dfWithNewColumn.col) - val selected = dfWithNewColumn - .select(newColumns: _*) - .as[U](TypedExpressionEncoder[U]) + val selected = + dfWithNewColumn.select(newColumns: _*).as[U](TypedExpressionEncoder[U]) TypedDataset.create[U](selected) } } /** - * Explodes a single column at a time. It only compiles if the type of column supports this operation. - * - * @example - * - * {{{ - * case class X(i: Int, j: Array[Int]) - * case class Y(i: Int, j: Int) - * - * val f: TypedDataset[X] = ??? - * val fNew: TypedDataset[Y] = f.explode('j).as[Y] - * }}} - * @param column the column we wish to explode - */ - def explode[A, TRep <: HList, V[_], OutMod <: HList, OutModValues <: HList, Out] - (column: Witness.Lt[Symbol]) - (implicit - i0: TypedColumn.Exists[T, column.T, V[A]], - i1: TypedEncoder[A], - i2: CatalystExplodableCollection[V], - i3: LabelledGeneric.Aux[T, TRep], - i4: Modifier.Aux[TRep, column.T, V[A], A, OutMod], - i5: Values.Aux[OutMod, OutModValues], - i6: Tupler.Aux[OutModValues, Out], - i7: TypedEncoder[Out] - ): TypedDataset[Out] = { - import org.apache.spark.sql.functions.{explode => sparkExplode} + * Explodes a single column at a time. It only compiles if the type of column supports this operation. + * + * @example + * + * {{{ + * case class X(i: Int, j: Array[Int]) + * case class Y(i: Int, j: Int) + * + * val f: TypedDataset[X] = ??? + * val fNew: TypedDataset[Y] = f.explode('j).as[Y] + * }}} + * @param column the column we wish to explode + */ + def explode[ + A, + TRep <: HList, + V[_], + OutMod <: HList, + OutModValues <: HList, + Out + ](column: Witness.Lt[Symbol] + )(implicit + i0: TypedColumn.Exists[T, column.T, V[A]], + i1: TypedEncoder[A], + i2: CatalystExplodableCollection[V], + i3: LabelledGeneric.Aux[T, TRep], + i4: Modifier.Aux[TRep, column.T, V[A], A, OutMod], + i5: Values.Aux[OutMod, OutModValues], + i6: Tupler.Aux[OutModValues, Out], + i7: TypedEncoder[Out] + ): TypedDataset[Out] = { + import org.apache.spark.sql.functions.{ explode => sparkExplode } val df = dataset.toDF() val trans = - df - .withColumn(column.value.name, sparkExplode(df(column.value.name))) + df.withColumn(column.value.name, sparkExplode(df(column.value.name))) .as[Out](TypedExpressionEncoder[Out]) TypedDataset.create[Out](trans) } /** - * Explodes a single column at a time. It only compiles if the type of column supports this operation. - * - * @example - * - * {{{ - * case class X(i: Int, j: Map[Int, Int]) - * case class Y(i: Int, j: (Int, Int)) - * - * val f: TypedDataset[X] = ??? 
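Sketches of the evidence-heavy operations above, following their Scaladoc examples (implicit `SparkSession` in scope):

{{{
case class X(i: Int, j: Int)
case class Y(i: Int, j: Int, k: Boolean)

val f: TypedDataset[X] = TypedDataset.create(Seq(X(1, 1), X(1, 10)))

// The target type Y fixes both the name and the type of the added column
val withK: TypedDataset[Y] = f.withColumn[Y](f('j) === 10)

case class A(i: Int, js: Array[Int])
case class B(i: Int, js: Int)

val as: TypedDataset[A] = TypedDataset.create(Seq(A(1, Array(1, 2, 3))))

// Each element of the array becomes its own row
val bs: TypedDataset[B] = as.explode('js).as[B]
}}}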
- * val fNew: TypedDataset[Y] = f.explodeMap('j).as[Y] - * }}} - * @param column the column we wish to explode - */ - def explodeMap[A, B, V[_, _], TRep <: HList, OutMod <: HList, OutModValues <: HList, Out] - (column: Witness.Lt[Symbol]) - (implicit - i0: TypedColumn.Exists[T, column.T, V[A, B]], - i1: TypedEncoder[A], - i2: TypedEncoder[B], - i3: LabelledGeneric.Aux[T, TRep], - i4: Modifier.Aux[TRep, column.T, V[A,B], (A, B), OutMod], - i5: Values.Aux[OutMod, OutModValues], - i6: Tupler.Aux[OutModValues, Out], - i7: TypedEncoder[Out] - ): TypedDataset[Out] = { - import org.apache.spark.sql.functions.{explode => sparkExplode, struct => sparkStruct, col => sparkCol} + * Explodes a single column at a time. It only compiles if the type of column supports this operation. + * + * @example + * + * {{{ + * case class X(i: Int, j: Map[Int, Int]) + * case class Y(i: Int, j: (Int, Int)) + * + * val f: TypedDataset[X] = ??? + * val fNew: TypedDataset[Y] = f.explodeMap('j).as[Y] + * }}} + * @param column the column we wish to explode + */ + def explodeMap[ + A, + B, + V[_, _], + TRep <: HList, + OutMod <: HList, + OutModValues <: HList, + Out + ](column: Witness.Lt[Symbol] + )(implicit + i0: TypedColumn.Exists[T, column.T, V[A, B]], + i1: TypedEncoder[A], + i2: TypedEncoder[B], + i3: LabelledGeneric.Aux[T, TRep], + i4: Modifier.Aux[TRep, column.T, V[A, B], (A, B), OutMod], + i5: Values.Aux[OutMod, OutModValues], + i6: Tupler.Aux[OutModValues, Out], + i7: TypedEncoder[Out] + ): TypedDataset[Out] = { + import org.apache.spark.sql.functions.{ + explode => sparkExplode, + struct => sparkStruct, + col => sparkCol + } val df = dataset.toDF() // select all columns, all original columns and [key, value] columns appeared after the map explode @@ -1271,7 +1612,10 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val exploded // map explode explodes it into [key, value] columns // the only way to put it into a column is to create a struct - .withColumn(columnRenamed, sparkStruct(exploded("key"), exploded("value"))) + .withColumn( + columnRenamed, + sparkStruct(exploded("key"), exploded("value")) + ) // selecting only original columns, we don't need [key, value] columns left in the DataFrame after the map explode .select(columns: _*) // rename columns back and form the result @@ -1281,72 +1625,81 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val } /** - * Flattens a column of type Option[A]. Compiles only if the selected column is of type Option[A]. - * - * - * @example - * - * {{{ - * case class X(i: Int, j: Option[Int]) - * case class Y(i: Int, j: Int) - * - * val f: TypedDataset[X] = ??? - * val fNew: TypedDataset[Y] = f.flattenOption('j).as[Y] - * }}} - * - * @param column the column we wish to flatten - */ - def flattenOption[A, TRep <: HList, V[_], OutMod <: HList, OutModValues <: HList, Out] - (column: Witness.Lt[Symbol]) - (implicit - i0: TypedColumn.Exists[T, column.T, V[A]], - i1: TypedEncoder[A], - i2: V[A] =:= Option[A], - i3: LabelledGeneric.Aux[T, TRep], - i4: Modifier.Aux[TRep, column.T, V[A], A, OutMod], - i5: Values.Aux[OutMod, OutModValues], - i6: Tupler.Aux[OutModValues, Out], - i7: TypedEncoder[Out] - ): TypedDataset[Out] = { + * Flattens a column of type Option[A]. Compiles only if the selected column is of type Option[A]. + * + * @example + * + * {{{ + * case class X(i: Int, j: Option[Int]) + * case class Y(i: Int, j: Int) + * + * val f: TypedDataset[X] = ??? 
+ * val fNew: TypedDataset[Y] = f.flattenOption('j).as[Y] + * }}} + * + * @param column the column we wish to flatten + */ + def flattenOption[ + A, + TRep <: HList, + V[_], + OutMod <: HList, + OutModValues <: HList, + Out + ](column: Witness.Lt[Symbol] + )(implicit + i0: TypedColumn.Exists[T, column.T, V[A]], + i1: TypedEncoder[A], + i2: V[A] =:= Option[A], + i3: LabelledGeneric.Aux[T, TRep], + i4: Modifier.Aux[TRep, column.T, V[A], A, OutMod], + i5: Values.Aux[OutMod, OutModValues], + i6: Tupler.Aux[OutModValues, Out], + i7: TypedEncoder[Out] + ): TypedDataset[Out] = { val df = dataset.toDF() - val trans = df.filter(df(column.value.name).isNotNull). - as[Out](TypedExpressionEncoder[Out]) + val trans = df + .filter(df(column.value.name).isNotNull) + .as[Out](TypedExpressionEncoder[Out]) TypedDataset.create[Out](trans) } } object TypedDataset { - def create[A](data: Seq[A]) - (implicit + + def create[A]( + data: Seq[A] + )(implicit encoder: TypedEncoder[A], sqlContext: SparkSession ): TypedDataset[A] = { - val dataset = sqlContext.createDataset(data)(TypedExpressionEncoder[A]) + val dataset = sqlContext.createDataset(data)(TypedExpressionEncoder[A]) - TypedDataset.create[A](dataset) - } + TypedDataset.create[A](dataset) + } - def create[A](data: RDD[A]) - (implicit + def create[A]( + data: RDD[A] + )(implicit encoder: TypedEncoder[A], sqlContext: SparkSession ): TypedDataset[A] = { - val dataset = sqlContext.createDataset(data)(TypedExpressionEncoder[A]) + val dataset = sqlContext.createDataset(data)(TypedExpressionEncoder[A]) - TypedDataset.create[A](dataset) - } + TypedDataset.create[A](dataset) + } def create[A: TypedEncoder](dataset: Dataset[A]): TypedDataset[A] = createUnsafe(dataset.toDF()) /** - * Creates a [[frameless.TypedDataset]] from a Spark [[org.apache.spark.sql.DataFrame]]. - * Note that the names and types need to align! - * - * This is an unsafe operation: If the schemas do not align, - * the error will be captured at runtime (not during compilation). - */ + * Creates a [[frameless.TypedDataset]] from a Spark [[org.apache.spark.sql.DataFrame]]. + * Note that the names and types need to align! + * + * This is an unsafe operation: If the schemas do not align, + * the error will be captured at runtime (not during compilation). + */ def createUnsafe[A: TypedEncoder](df: DataFrame): TypedDataset[A] = { val e = TypedEncoder[A] val output: Seq[Attribute] = df.queryExecution.analyzed.output @@ -1358,7 +1711,8 @@ object TypedDataset { throw new IllegalStateException( s"Unsupported creation of TypedDataset with ${targetFields.size} column(s) " + s"from a DataFrame with ${output.size} columns. " + - "Try to `select()` the proper columns in the right order before calling `create()`.") + "Try to `select()` the proper columns in the right order before calling `create()`." + ) } // Adapt names if they are not the same (note: types still might not match) @@ -1368,7 +1722,7 @@ object TypedDataset { val canSelect = targetColNames.toSet.subsetOf(output.map(_.name).toSet) val reshaped = if (shouldReshape && canSelect) { - df.select(targetColNames.head, targetColNames.tail:_*) + df.select(targetColNames.head, targetColNames.tail: _*) } else if (shouldReshape) { df.toDF(targetColNames: _*) } else { @@ -1378,9 +1732,14 @@ object TypedDataset { new TypedDataset[A](reshaped.as[A](TypedExpressionEncoder[A])) } - /** Prefer `TypedDataset.create` over `TypedDataset.unsafeCreate` unless you - * know what you are doing. 
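How the constructors above are typically reached (the Parquet path and names are illustrative):

{{{
import org.apache.spark.sql.{ DataFrame, SparkSession }

case class Person(name: String, age: Int)

implicit val spark: SparkSession =
  SparkSession.builder().master("local[*]").getOrCreate()

// Safe constructor: the schema is derived from the TypedEncoder at compile time
val people: TypedDataset[Person] = TypedDataset.create(Seq(Person("Ana", 30)))

// Unsafe constructor: column names/types must already line up with Person,
// and any mismatch only surfaces at runtime
val df: DataFrame = spark.read.parquet("people.parquet")
val people2: TypedDataset[Person] = TypedDataset.createUnsafe[Person](df)
}}}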
*/ - @deprecated("Prefer TypedDataset.create over TypedDataset.unsafeCreate", "0.3.0") + /** + * Prefer `TypedDataset.create` over `TypedDataset.unsafeCreate` unless you + * know what you are doing. + */ + @deprecated( + "Prefer TypedDataset.create over TypedDataset.unsafeCreate", + "0.3.0" + ) def unsafeCreate[A: TypedEncoder](dataset: Dataset[A]): TypedDataset[A] = { new TypedDataset[A](dataset) } diff --git a/dataset/src/main/scala/frameless/TypedDatasetForwarded.scala b/dataset/src/main/scala/frameless/TypedDatasetForwarded.scala index d417caf8e..b4beac7bf 100644 --- a/dataset/src/main/scala/frameless/TypedDatasetForwarded.scala +++ b/dataset/src/main/scala/frameless/TypedDatasetForwarded.scala @@ -6,366 +6,428 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, DataFrameWriter, SQLContext, SparkSession} +import org.apache.spark.sql.{ + DataFrame, + DataFrameWriter, + SQLContext, + SparkSession +} import org.apache.spark.storage.StorageLevel import scala.util.Random -/** This trait implements [[TypedDataset]] methods that have the same signature - * than their `Dataset` equivalent. Each method simply forwards the call to the - * underlying `Dataset`. - * - * Documentation marked "apache/spark" is thanks to apache/spark Contributors - * at https://github.com/apache/spark, licensed under Apache v2.0 available at - * http://www.apache.org/licenses/LICENSE-2.0 - */ +/** + * This trait implements [[TypedDataset]] methods that have the same signature + * than their `Dataset` equivalent. Each method simply forwards the call to the + * underlying `Dataset`. + * + * Documentation marked "apache/spark" is thanks to apache/spark Contributors + * at https://github.com/apache/spark, licensed under Apache v2.0 available at + * http://www.apache.org/licenses/LICENSE-2.0 + */ trait TypedDatasetForwarded[T] { self: TypedDataset[T] => override def toString: String = dataset.toString /** - * Returns a `SparkSession` from this [[TypedDataset]]. - */ + * Returns a `SparkSession` from this [[TypedDataset]]. + */ def sparkSession: SparkSession = dataset.sparkSession /** - * Returns a `SQLContext` from this [[TypedDataset]]. - */ + * Returns a `SQLContext` from this [[TypedDataset]]. + */ def sqlContext: SQLContext = dataset.sqlContext /** - * Returns the schema of this Dataset. - * - * apache/spark - */ + * Returns the schema of this Dataset. + * + * apache/spark + */ def schema: StructType = dataset.schema - /** Prints the schema of the underlying `Dataset` to the console in a nice tree format. - * - * apache/spark + /** + * Prints the schema of the underlying `Dataset` to the console in a nice tree format. + * + * apache/spark */ def printSchema(): Unit = dataset.printSchema() - /** Prints the plans (logical and physical) to the console for debugging purposes. - * - * apache/spark + /** + * Prints the plans (logical and physical) to the console for debugging purposes. + * + * apache/spark */ def explain(extended: Boolean = false): Unit = dataset.explain(extended) /** - * Returns a `QueryExecution` from this [[TypedDataset]]. - * - * It is the primary workflow for executing relational queries using Spark. Designed to allow easy - * access to the intermediate phases of query execution for developers. - * - * apache/spark - */ + * Returns a `QueryExecution` from this [[TypedDataset]]. 
+ * + * It is the primary workflow for executing relational queries using Spark. Designed to allow easy + * access to the intermediate phases of query execution for developers. + * + * apache/spark + */ def queryExecution: QueryExecution = dataset.queryExecution - /** Converts this strongly typed collection of data to generic Dataframe. In contrast to the - * strongly typed objects that Dataset operations work on, a Dataframe returns generic Row - * objects that allow fields to be accessed by ordinal or name. - * - * apache/spark - */ + /** + * Converts this strongly typed collection of data to generic Dataframe. In contrast to the + * strongly typed objects that Dataset operations work on, a Dataframe returns generic Row + * objects that allow fields to be accessed by ordinal or name. + * + * apache/spark + */ def toDF(): DataFrame = dataset.toDF() - /** Converts this [[TypedDataset]] to an RDD. - * - * apache/spark - */ + /** + * Converts this [[TypedDataset]] to an RDD. + * + * apache/spark + */ def rdd: RDD[T] = dataset.rdd - /** Returns a new [[TypedDataset]] that has exactly `numPartitions` partitions. - * - * apache/spark - */ + /** + * Returns a new [[TypedDataset]] that has exactly `numPartitions` partitions. + * + * apache/spark + */ def repartition(numPartitions: Int): TypedDataset[T] = TypedDataset.create(dataset.repartition(numPartitions)) - /** - * Get the [[TypedDataset]]'s current storage level, or StorageLevel.NONE if not persisted. - * - * apache/spark - */ + * Get the [[TypedDataset]]'s current storage level, or StorageLevel.NONE if not persisted. + * + * apache/spark + */ def storageLevel(): StorageLevel = dataset.storageLevel /** - * Returns the content of the [[TypedDataset]] as a Dataset of JSON strings. - * - * apache/spark - */ + * Returns the content of the [[TypedDataset]] as a Dataset of JSON strings. + * + * apache/spark + */ def toJSON: TypedDataset[String] = TypedDataset.create(dataset.toJSON) /** - * Interface for saving the content of the non-streaming [[TypedDataset]] out into external storage. - * - * apache/spark - */ + * Interface for saving the content of the non-streaming [[TypedDataset]] out into external storage. + * + * apache/spark + */ def write: DataFrameWriter[T] = dataset.write /** - * Interface for saving the content of the streaming Dataset out into external storage. - * - * apache/spark - */ + * Interface for saving the content of the streaming Dataset out into external storage. + * + * apache/spark + */ def writeStream: DataStreamWriter[T] = dataset.writeStream - - /** Returns a new [[TypedDataset]] that has exactly `numPartitions` partitions. - * Similar to coalesce defined on an RDD, this operation results in a narrow dependency, e.g. - * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of - * the 100 new partitions will claim 10 of the current partitions. - * - * apache/spark - */ + + /** + * Returns a new [[TypedDataset]] that has exactly `numPartitions` partitions. + * Similar to coalesce defined on an RDD, this operation results in a narrow dependency, e.g. + * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of + * the 100 new partitions will claim 10 of the current partitions. + * + * apache/spark + */ def coalesce(numPartitions: Int): TypedDataset[T] = TypedDataset.create(dataset.coalesce(numPartitions)) /** - * Returns an `Array` that contains all column names in this [[TypedDataset]]. 
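These forwarders behave exactly like their `Dataset` counterparts; a brief sketch (output path illustrative):

{{{
case class Person(name: String, age: Int)
val people: TypedDataset[Person] = ???

val tenParts: TypedDataset[Person] = people.repartition(10)
val twoParts: TypedDataset[Person] = tenParts.coalesce(2) // narrow dependency, no shuffle

val asJson: TypedDataset[String] = people.toJSON
people.write.parquet("out/people")
}}}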
- */ + * Returns an `Array` that contains all column names in this [[TypedDataset]]. + */ def columns: Array[String] = dataset.columns - /** Concise syntax for chaining custom transformations. - * - * apache/spark - */ + /** + * Concise syntax for chaining custom transformations. + * + * apache/spark + */ def transform[U](t: TypedDataset[T] => TypedDataset[U]): TypedDataset[U] = t(this) - /** Returns a new Dataset by taking the first `n` rows. The difference between this function - * and `head` is that `head` is an action and returns an array (by triggering query execution) - * while `limit` returns a new Dataset. - * - * apache/spark - */ + /** + * Returns a new Dataset by taking the first `n` rows. The difference between this function + * and `head` is that `head` is an action and returns an array (by triggering query execution) + * while `limit` returns a new Dataset. + * + * apache/spark + */ def limit(n: Int): TypedDataset[T] = TypedDataset.create(dataset.limit(n)) - /** Returns a new [[TypedDataset]] by sampling a fraction of records. - * - * apache/spark - */ - def sample(withReplacement: Boolean, fraction: Double, seed: Long = Random.nextLong()): TypedDataset[T] = + /** + * Returns a new [[TypedDataset]] by sampling a fraction of records. + * + * apache/spark + */ + def sample( + withReplacement: Boolean, + fraction: Double, + seed: Long = Random.nextLong() + ): TypedDataset[T] = TypedDataset.create(dataset.sample(withReplacement, fraction, seed)) - /** Returns a new [[TypedDataset]] that contains only the unique elements of this [[TypedDataset]]. - * - * Note that, equality checking is performed directly on the encoded representation of the data - * and thus is not affected by a custom `equals` function defined on `T`. - * - * apache/spark - */ + /** + * Returns a new [[TypedDataset]] that contains only the unique elements of this [[TypedDataset]]. + * + * Note that, equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. + * + * apache/spark + */ def distinct: TypedDataset[T] = TypedDataset.create(dataset.distinct()) /** - * Returns a best-effort snapshot of the files that compose this [[TypedDataset]]. This method simply - * asks each constituent BaseRelation for its respective files and takes the union of all results. - * Depending on the source relations, this may not find all input files. Duplicates are removed. - * - * apache/spark - */ + * Returns a best-effort snapshot of the files that compose this [[TypedDataset]]. This method simply + * asks each constituent BaseRelation for its respective files and takes the union of all results. + * Depending on the source relations, this may not find all input files. Duplicates are removed. + * + * apache/spark + */ def inputFiles: Array[String] = dataset.inputFiles /** - * Returns true if the `collect` and `take` methods can be run locally - * (without any Spark executors). - * - * apache/spark - */ + * Returns true if the `collect` and `take` methods can be run locally + * (without any Spark executors). + * + * apache/spark + */ def isLocal: Boolean = dataset.isLocal /** - * Returns true if this [[TypedDataset]] contains one or more sources that continuously - * return data as it arrives. A [[TypedDataset]] that reads data from a streaming source - * must be executed as a `StreamingQuery` using the `start()` method in - * `DataStreamWriter`. Methods that return a single answer, e.g. 
`count()` or - * `collect()`, will throw an `AnalysisException` when there is a streaming - * source present. - * - * apache/spark - */ + * Returns true if this [[TypedDataset]] contains one or more sources that continuously + * return data as it arrives. A [[TypedDataset]] that reads data from a streaming source + * must be executed as a `StreamingQuery` using the `start()` method in + * `DataStreamWriter`. Methods that return a single answer, e.g. `count()` or + * `collect()`, will throw an `AnalysisException` when there is a streaming + * source present. + * + * apache/spark + */ def isStreaming: Boolean = dataset.isStreaming - /** Returns a new [[TypedDataset]] that contains only the elements of this [[TypedDataset]] that are also - * present in `other`. - * - * Note that, equality checking is performed directly on the encoded representation of the data - * and thus is not affected by a custom `equals` function defined on `T`. - * - * apache/spark - */ + /** + * Returns a new [[TypedDataset]] that contains only the elements of this [[TypedDataset]] that are also + * present in `other`. + * + * Note that, equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. + * + * apache/spark + */ def intersect(other: TypedDataset[T]): TypedDataset[T] = TypedDataset.create(dataset.intersect(other.dataset)) /** - * Randomly splits this [[TypedDataset]] with the provided weights. - * Weights for splits, will be normalized if they don't sum to 1. - * - * apache/spark - */ + * Randomly splits this [[TypedDataset]] with the provided weights. + * Weights for splits, will be normalized if they don't sum to 1. + * + * apache/spark + */ // $COVERAGE-OFF$ We can not test this method because it is non-deterministic. def randomSplit(weights: Array[Double]): Array[TypedDataset[T]] = dataset.randomSplit(weights).map(TypedDataset.create[T]) // $COVERAGE-ON$ /** - * Randomly splits this [[TypedDataset]] with the provided weights. - * Weights for splits, will be normalized if they don't sum to 1. - * - * apache/spark - */ + * Randomly splits this [[TypedDataset]] with the provided weights. + * Weights for splits, will be normalized if they don't sum to 1. + * + * apache/spark + */ def randomSplit(weights: Array[Double], seed: Long): Array[TypedDataset[T]] = dataset.randomSplit(weights, seed).map(TypedDataset.create[T]) /** - * Returns a Java list that contains randomly split [[TypedDataset]] with the provided weights. - * Weights for splits, will be normalized if they don't sum to 1. - * - * apache/spark - */ - def randomSplitAsList(weights: Array[Double], seed: Long): util.List[TypedDataset[T]] = { + * Returns a Java list that contains randomly split [[TypedDataset]] with the provided weights. + * Weights for splits, will be normalized if they don't sum to 1. + * + * apache/spark + */ + def randomSplitAsList( + weights: Array[Double], + seed: Long + ): util.List[TypedDataset[T]] = { val values = randomSplit(weights, seed) java.util.Arrays.asList(values: _*) } - - /** Returns a new Dataset containing rows in this Dataset but not in another Dataset. - * This is equivalent to `EXCEPT` in SQL. - * - * Note that, equality checking is performed directly on the encoded representation of the data - * and thus is not affected by a custom `equals` function defined on `T`. - * - * apache/spark - */ + /** + * Returns a new Dataset containing rows in this Dataset but not in another Dataset. 
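A compact sketch of the row-level forwarders above (seed and fractions illustrative):

{{{
case class Person(name: String, age: Int)
val a: TypedDataset[Person] = ???
val b: TypedDataset[Person] = ???

val firstTen: TypedDataset[Person]   = a.limit(10)
val tenPercent: TypedDataset[Person] = a.sample(withReplacement = false, fraction = 0.1)
val deduped: TypedDataset[Person]    = a.distinct

val Array(train, test) = a.randomSplit(Array(0.8, 0.2), seed = 42L)

val inBoth: TypedDataset[Person]  = a.intersect(b)
val onlyInA: TypedDataset[Person] = a.except(b)
}}}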
+ * This is equivalent to `EXCEPT` in SQL. + * + * Note that, equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. + * + * apache/spark + */ def except(other: TypedDataset[T]): TypedDataset[T] = TypedDataset.create(dataset.except(other.dataset)) - /** Persist this [[TypedDataset]] with the default storage level (`MEMORY_AND_DISK`). - * - * apache/spark - */ + /** + * Persist this [[TypedDataset]] with the default storage level (`MEMORY_AND_DISK`). + * + * apache/spark + */ def cache(): TypedDataset[T] = TypedDataset.create(dataset.cache()) - /** Persist this [[TypedDataset]] with the given storage level. - * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, - * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, `MEMORY_AND_DISK_2`, etc. - * - * apache/spark - */ - def persist(newLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK): TypedDataset[T] = + /** + * Persist this [[TypedDataset]] with the given storage level. + * @param newLevel One of: `MEMORY_ONLY`, `MEMORY_AND_DISK`, `MEMORY_ONLY_SER`, + * `MEMORY_AND_DISK_SER`, `DISK_ONLY`, `MEMORY_ONLY_2`, `MEMORY_AND_DISK_2`, etc. + * + * apache/spark + */ + def persist( + newLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK + ): TypedDataset[T] = TypedDataset.create(dataset.persist(newLevel)) - /** Mark the [[TypedDataset]] as non-persistent, and remove all blocks for it from memory and disk. - * @param blocking Whether to block until all blocks are deleted. - * - * apache/spark - */ + /** + * Mark the [[TypedDataset]] as non-persistent, and remove all blocks for it from memory and disk. + * @param blocking Whether to block until all blocks are deleted. + * + * apache/spark + */ def unpersist(blocking: Boolean = false): TypedDataset[T] = TypedDataset.create(dataset.unpersist(blocking)) // $COVERAGE-OFF$ We do not test deprecated method since forwarded methods are tested. 
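The caching forwarders above, in a minimal sketch:

{{{
import org.apache.spark.storage.StorageLevel

case class Person(name: String, age: Int)
val people: TypedDataset[Person] = ???

val cached: TypedDataset[Person] = people.cache() // default MEMORY_AND_DISK
val onDisk: TypedDataset[Person] = people.persist(StorageLevel.DISK_ONLY)
val freed: TypedDataset[Person]  = onDisk.unpersist(blocking = true)
}}}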
- @deprecated("deserialized methods have moved to a separate section to highlight their runtime overhead", "0.4.0") + @deprecated( + "deserialized methods have moved to a separate section to highlight their runtime overhead", + "0.4.0" + ) def map[U: TypedEncoder](func: T => U): TypedDataset[U] = deserialized.map(func) - @deprecated("deserialized methods have moved to a separate section to highlight their runtime overhead", "0.4.0") - def mapPartitions[U: TypedEncoder](func: Iterator[T] => Iterator[U]): TypedDataset[U] = + @deprecated( + "deserialized methods have moved to a separate section to highlight their runtime overhead", + "0.4.0" + ) + def mapPartitions[U: TypedEncoder]( + func: Iterator[T] => Iterator[U] + ): TypedDataset[U] = deserialized.mapPartitions(func) - @deprecated("deserialized methods have moved to a separate section to highlight their runtime overhead", "0.4.0") + @deprecated( + "deserialized methods have moved to a separate section to highlight their runtime overhead", + "0.4.0" + ) def flatMap[U: TypedEncoder](func: T => TraversableOnce[U]): TypedDataset[U] = deserialized.flatMap(func) - @deprecated("deserialized methods have moved to a separate section to highlight their runtime overhead", "0.4.0") + @deprecated( + "deserialized methods have moved to a separate section to highlight their runtime overhead", + "0.4.0" + ) def filter(func: T => Boolean): TypedDataset[T] = deserialized.filter(func) - @deprecated("deserialized methods have moved to a separate section to highlight their runtime overhead", "0.4.0") + @deprecated( + "deserialized methods have moved to a separate section to highlight their runtime overhead", + "0.4.0" + ) def reduceOption[F[_]: SparkDelay](func: (T, T) => T): F[Option[T]] = deserialized.reduceOption(func) // $COVERAGE-ON$ - /** Methods on `TypedDataset[T]` that go through a full serialization and - * deserialization of `T`, and execute outside of the Catalyst runtime. - * - * @example The correct way to do a projection on a single column is to - * use the `select` method as follows: - * - * {{{ - * ds: TypedDataset[(String, String, String)] -> ds.select(ds('_2)).run() - * }}} - * - * Spark provides an alternative way to obtain the same resulting `Dataset`, - * using the `map` method: - * - * {{{ - * ds: TypedDataset[(String, String, String)] -> ds.deserialized.map(_._2).run() - * }}} - * - * This second approach is however substantially slower than the first one, - * and should be avoided as possible. Indeed, under the hood this `map` will - * deserialize the entire `Tuple3` to an full JVM object, call the apply - * method of the `_._2` closure on it, and serialize the resulting String back - * to its Catalyst representation. - */ + /** + * Methods on `TypedDataset[T]` that go through a full serialization and + * deserialization of `T`, and execute outside of the Catalyst runtime. + * + * @example The correct way to do a projection on a single column is to + * use the `select` method as follows: + * + * {{{ + * ds: TypedDataset[(String, String, String)] -> ds.select(ds('_2)).run() + * }}} + * + * Spark provides an alternative way to obtain the same resulting `Dataset`, + * using the `map` method: + * + * {{{ + * ds: TypedDataset[(String, String, String)] -> ds.deserialized.map(_._2).run() + * }}} + * + * This second approach is however substantially slower than the first one, + * and should be avoided as possible. 
Indeed, under the hood this `map` will + * deserialize the entire `Tuple3` to an full JVM object, call the apply + * method of the `_._2` closure on it, and serialize the resulting String back + * to its Catalyst representation. + */ object deserialized { - /** Returns a new [[TypedDataset]] that contains the result of applying `func` to each element. - * - * apache/spark - */ + + /** + * Returns a new [[TypedDataset]] that contains the result of applying `func` to each element. + * + * apache/spark + */ def map[U: TypedEncoder](func: T => U): TypedDataset[U] = TypedDataset.create(self.dataset.map(func)(TypedExpressionEncoder[U])) - /** Returns a new [[TypedDataset]] that contains the result of applying `func` to each partition. - * - * apache/spark - */ - def mapPartitions[U: TypedEncoder](func: Iterator[T] => Iterator[U]): TypedDataset[U] = - TypedDataset.create(self.dataset.mapPartitions(func)(TypedExpressionEncoder[U])) - - /** Returns a new [[TypedDataset]] by first applying a function to all elements of this [[TypedDataset]], - * and then flattening the results. - * - * apache/spark - */ - def flatMap[U: TypedEncoder](func: T => TraversableOnce[U]): TypedDataset[U] = + /** + * Returns a new [[TypedDataset]] that contains the result of applying `func` to each partition. + * + * apache/spark + */ + def mapPartitions[U: TypedEncoder]( + func: Iterator[T] => Iterator[U] + ): TypedDataset[U] = + TypedDataset.create( + self.dataset.mapPartitions(func)(TypedExpressionEncoder[U]) + ) + + /** + * Returns a new [[TypedDataset]] by first applying a function to all elements of this [[TypedDataset]], + * and then flattening the results. + * + * apache/spark + */ + def flatMap[U: TypedEncoder]( + func: T => TraversableOnce[U] + ): TypedDataset[U] = TypedDataset.create(self.dataset.flatMap(func)(TypedExpressionEncoder[U])) - /** Returns a new [[TypedDataset]] that only contains elements where `func` returns `true`. - * - * apache/spark - */ + /** + * Returns a new [[TypedDataset]] that only contains elements where `func` returns `true`. + * + * apache/spark + */ def filter(func: T => Boolean): TypedDataset[T] = TypedDataset.create(self.dataset.filter(func)) - /** Optionally reduces the elements of this [[TypedDataset]] using the specified binary function. The given - * `func` must be commutative and associative or the result may be non-deterministic. - * - * Differs from `Dataset#reduce` by wrapping its result into an `Option` and an effect-suspending `F`. - */ - def reduceOption[F[_]](func: (T, T) => T)(implicit F: SparkDelay[F]): F[Option[T]] = + /** + * Optionally reduces the elements of this [[TypedDataset]] using the specified binary function. The given + * `func` must be commutative and associative or the result may be non-deterministic. + * + * Differs from `Dataset#reduce` by wrapping its result into an `Option` and an effect-suspending `F`. 
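The trade-off described above, side by side on an illustrative triple dataset:

{{{
val ds: TypedDataset[(String, String, String)] = ???

// Stays inside Catalyst: only the projected column is ever deserialized
val viaSelect: TypedDataset[String] = ds.select(ds('_2))

// Goes through the deserialized API: every Tuple3 is materialized on the JVM heap
// before `_._2` is applied, and the result is serialized back again
val viaMap: TypedDataset[String] = ds.deserialized.map(_._2)
}}}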
+ */ + def reduceOption[F[_]]( + func: (T, T) => T + )(implicit + F: SparkDelay[F] + ): F[Option[T]] = F.delay { try { Option(self.dataset.reduce(func)) diff --git a/dataset/src/main/scala/frameless/TypedExpressionEncoder.scala b/dataset/src/main/scala/frameless/TypedExpressionEncoder.scala index 5b78cd292..121fd7fcc 100644 --- a/dataset/src/main/scala/frameless/TypedExpressionEncoder.scala +++ b/dataset/src/main/scala/frameless/TypedExpressionEncoder.scala @@ -3,18 +3,23 @@ package frameless import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{BoundReference, CreateNamedStruct, If} +import org.apache.spark.sql.catalyst.expressions.{ + BoundReference, + CreateNamedStruct, + If +} import org.apache.spark.sql.types.StructType object TypedExpressionEncoder { - /** In Spark, DataFrame has always schema of StructType - * - * DataFrames of primitive types become records - * with a single field called "value" set in ExpressionEncoder. - */ + /** + * In Spark, DataFrame has always schema of StructType + * + * DataFrames of primitive types become records + * with a single field called "value" set in ExpressionEncoder. + */ def targetStructType[A](encoder: TypedEncoder[A]): StructType = - encoder.catalystRepr match { + encoder.catalystRepr match { case x: StructType => if (encoder.nullable) StructType(x.fields.map(_.copy(nullable = true))) else x @@ -22,7 +27,10 @@ object TypedExpressionEncoder { case dt => new StructType().add("value", dt, nullable = encoder.nullable) } - def apply[T](implicit encoder: TypedEncoder[T]): Encoder[T] = { + def apply[T]( + implicit + encoder: TypedEncoder[T] + ): Encoder[T] = { val in = BoundReference(0, encoder.jvmRepr, encoder.nullable) val (out, serializer) = encoder.toCatalyst(in) match { @@ -46,4 +54,3 @@ object TypedExpressionEncoder { ) } } - diff --git a/dataset/src/main/scala/frameless/With.scala b/dataset/src/main/scala/frameless/With.scala index 11ceaa35b..85ce1d145 100644 --- a/dataset/src/main/scala/frameless/With.scala +++ b/dataset/src/main/scala/frameless/With.scala @@ -1,14 +1,15 @@ package frameless -/** Compute the intersection of two types: - * - * - With[A, A] = A - * - With[A, B] = A with B (when A != B) - * - * This type function is needed to prevent IDEs from infering large types - * with shape `A with A with ... with A`. These types could be confusing for - * both end users and IDE's type checkers. - */ +/** + * Compute the intersection of two types: + * + * - With[A, A] = A + * - With[A, B] = A with B (when A != B) + * + * This type function is needed to prevent IDEs from infering large types + * with shape `A with A with ... with A`. These types could be confusing for + * both end users and IDE's type checkers. 
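A sketch of handing a frameless-derived encoder to plain Spark APIs via `TypedExpressionEncoder` (this is what `TypedDataset.create` does internally; names are illustrative):

{{{
import org.apache.spark.sql.{ Encoder, SparkSession }

case class Person(name: String, age: Int)

val spark: SparkSession = ???

// Derive a vanilla Spark Encoder from the frameless TypedEncoder
implicit val personEncoder: Encoder[Person] = TypedExpressionEncoder[Person]

val ds = spark.createDataset(Seq(Person("Ana", 30)))
}}}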
+ */ trait With[A, B] { type Out } object With extends LowPrioWith { diff --git a/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala b/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala index e371ea048..61a8fdbab 100644 --- a/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala +++ b/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala @@ -3,71 +3,90 @@ package functions import org.apache.spark.sql.FramelessInternals.expr import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.{functions => sparkFunctions} +import org.apache.spark.sql.{ functions => sparkFunctions } import frameless.syntax._ import scala.annotation.nowarn trait AggregateFunctions { - /** Aggregate function: returns the number of items in a group. - * - * apache/spark - */ + + /** + * Aggregate function: returns the number of items in a group. + * + * apache/spark + */ def count[T](): TypedAggregate[T, Long] = sparkFunctions.count(sparkFunctions.lit(1)).typedAggregate - /** Aggregate function: returns the number of items in a group for which the selected column is not null. - * - * apache/spark - */ + /** + * Aggregate function: returns the number of items in a group for which the selected column is not null. + * + * apache/spark + */ def count[T](column: TypedColumn[T, _]): TypedAggregate[T, Long] = sparkFunctions.count(column.untyped).typedAggregate - /** Aggregate function: returns the number of distinct items in a group. - * - * apache/spark - */ + /** + * Aggregate function: returns the number of distinct items in a group. + * + * apache/spark + */ def countDistinct[T](column: TypedColumn[T, _]): TypedAggregate[T, Long] = sparkFunctions.countDistinct(column.untyped).typedAggregate - /** Aggregate function: returns the approximate number of distinct items in a group. - */ - def approxCountDistinct[T](column: TypedColumn[T, _]): TypedAggregate[T, Long] = + /** + * Aggregate function: returns the approximate number of distinct items in a group. + */ + def approxCountDistinct[T]( + column: TypedColumn[T, _] + ): TypedAggregate[T, Long] = sparkFunctions.approx_count_distinct(column.untyped).typedAggregate - /** Aggregate function: returns the approximate number of distinct items in a group. - * - * @param rsd maximum estimation error allowed (default = 0.05) - * - * apache/spark - */ - def approxCountDistinct[T](column: TypedColumn[T, _], rsd: Double): TypedAggregate[T, Long] = + /** + * Aggregate function: returns the approximate number of distinct items in a group. + * + * @param rsd maximum estimation error allowed (default = 0.05) + * + * apache/spark + */ + def approxCountDistinct[T]( + column: TypedColumn[T, _], + rsd: Double + ): TypedAggregate[T, Long] = sparkFunctions.approx_count_distinct(column.untyped, rsd).typedAggregate - /** Aggregate function: returns a list of objects with duplicates. - * - * apache/spark - */ - def collectList[T, A: TypedEncoder](column: TypedColumn[T, A]): TypedAggregate[T, Vector[A]] = + /** + * Aggregate function: returns a list of objects with duplicates. + * + * apache/spark + */ + def collectList[T, A: TypedEncoder]( + column: TypedColumn[T, A] + ): TypedAggregate[T, Vector[A]] = sparkFunctions.collect_list(column.untyped).typedAggregate - /** Aggregate function: returns a set of objects with duplicate elements eliminated. 
- * - * apache/spark - */ - def collectSet[T, A: TypedEncoder](column: TypedColumn[T, A]): TypedAggregate[T, Vector[A]] = + /** + * Aggregate function: returns a set of objects with duplicate elements eliminated. + * + * apache/spark + */ + def collectSet[T, A: TypedEncoder]( + column: TypedColumn[T, A] + ): TypedAggregate[T, Vector[A]] = sparkFunctions.collect_set(column.untyped).typedAggregate - /** Aggregate function: returns the sum of all values in the given column. - * - * apache/spark - */ - def sum[A, T, Out](column: TypedColumn[T, A])( - implicit - summable: CatalystSummable[A, Out], - oencoder: TypedEncoder[Out], - aencoder: TypedEncoder[A] - ): TypedAggregate[T, Out] = { + /** + * Aggregate function: returns the sum of all values in the given column. + * + * apache/spark + */ + def sum[A, T, Out]( + column: TypedColumn[T, A] + )(implicit + summable: CatalystSummable[A, Out], + oencoder: TypedEncoder[Out], + aencoder: TypedEncoder[A] + ): TypedAggregate[T, Out] = { val zeroExpr = Literal.create(summable.zero, TypedEncoder[A].catalystRepr) val sumExpr = expr(sparkFunctions.sum(column.untyped)) val sumOrZero = Coalesce(Seq(sumExpr, zeroExpr)) @@ -75,17 +94,19 @@ trait AggregateFunctions { new TypedAggregate[T, Out](sumOrZero) } - /** Aggregate function: returns the sum of distinct values in the column. - * - * apache/spark - */ + /** + * Aggregate function: returns the sum of distinct values in the column. + * + * apache/spark + */ @nowarn // supress sparkFunctions.sumDistinct call which is used to maintain Spark 3.1.x backwards compat - def sumDistinct[A, T, Out](column: TypedColumn[T, A])( - implicit - summable: CatalystSummable[A, Out], - oencoder: TypedEncoder[Out], - aencoder: TypedEncoder[A] - ): TypedAggregate[T, Out] = { + def sumDistinct[A, T, Out]( + column: TypedColumn[T, A] + )(implicit + summable: CatalystSummable[A, Out], + oencoder: TypedEncoder[Out], + aencoder: TypedEncoder[A] + ): TypedAggregate[T, Out] = { val zeroExpr = Literal.create(summable.zero, TypedEncoder[A].catalystRepr) val sumExpr = expr(sparkFunctions.sumDistinct(column.untyped)) val sumOrZero = Coalesce(Seq(sumExpr, zeroExpr)) @@ -93,186 +114,225 @@ trait AggregateFunctions { new TypedAggregate[T, Out](sumOrZero) } - /** Aggregate function: returns the average of the values in a group. - * - * apache/spark - */ - def avg[A, T, Out](column: TypedColumn[T, A])( - implicit - averageable: CatalystAverageable[A, Out], - oencoder: TypedEncoder[Out] - ): TypedAggregate[T, Out] = { + /** + * Aggregate function: returns the average of the values in a group. + * + * apache/spark + */ + def avg[A, T, Out]( + column: TypedColumn[T, A] + )(implicit + averageable: CatalystAverageable[A, Out], + oencoder: TypedEncoder[Out] + ): TypedAggregate[T, Out] = { new TypedAggregate[T, Out](sparkFunctions.avg(column.untyped)) } - /** Aggregate function: returns the unbiased variance of the values in a group. - * - * @note In Spark variance always returns Double - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#186]] - * - * apache/spark - */ - def variance[A: CatalystVariance, T](column: TypedColumn[T, A]): TypedAggregate[T, Double] = + /** + * Aggregate function: returns the unbiased variance of the values in a group. 
+ * + * @note In Spark variance always returns Double + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#186]] + * + * apache/spark + */ + def variance[A: CatalystVariance, T]( + column: TypedColumn[T, A] + ): TypedAggregate[T, Double] = sparkFunctions.variance(column.untyped).typedAggregate - /** Aggregate function: returns the sample standard deviation. - * - * @note In Spark stddev always returns Double - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#155]] - * - * apache/spark - */ - def stddev[A: CatalystVariance, T](column: TypedColumn[T, A]): TypedAggregate[T, Double] = + /** + * Aggregate function: returns the sample standard deviation. + * + * @note In Spark stddev always returns Double + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#155]] + * + * apache/spark + */ + def stddev[A: CatalystVariance, T]( + column: TypedColumn[T, A] + ): TypedAggregate[T, Double] = sparkFunctions.stddev(column.untyped).typedAggregate /** - * Aggregate function: returns the standard deviation of a column by population. - * - * @note In Spark stddev always returns Double - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L143]] - * - * apache/spark - */ - def stddevPop[A, T](column: TypedColumn[T, A])(implicit ev: CatalystCast[A, Double]): TypedAggregate[T, Option[Double]] = { + * Aggregate function: returns the standard deviation of a column by population. + * + * @note In Spark stddev always returns Double + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L143]] + * + * apache/spark + */ + def stddevPop[A, T]( + column: TypedColumn[T, A] + )(implicit + ev: CatalystCast[A, Double] + ): TypedAggregate[T, Option[Double]] = { new TypedAggregate[T, Option[Double]]( sparkFunctions.stddev_pop(column.cast[Double].untyped) ) } /** - * Aggregate function: returns the standard deviation of a column by sample. - * - * @note In Spark stddev always returns Double - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L160]] - * - * apache/spark - */ - def stddevSamp[A, T](column: TypedColumn[T, A])(implicit ev: CatalystCast[A, Double] ): TypedAggregate[T, Option[Double]] = { + * Aggregate function: returns the standard deviation of a column by sample. 
+ * + * @note In Spark stddev always returns Double + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L160]] + * + * apache/spark + */ + def stddevSamp[A, T]( + column: TypedColumn[T, A] + )(implicit + ev: CatalystCast[A, Double] + ): TypedAggregate[T, Option[Double]] = { new TypedAggregate[T, Option[Double]]( sparkFunctions.stddev_samp(column.cast[Double].untyped) ) } - /** Aggregate function: returns the maximum value of the column in a group. - * - * apache/spark - */ - def max[A: CatalystOrdered, T](column: TypedColumn[T, A]): TypedAggregate[T, A] = { + /** + * Aggregate function: returns the maximum value of the column in a group. + * + * apache/spark + */ + def max[A: CatalystOrdered, T]( + column: TypedColumn[T, A] + ): TypedAggregate[T, A] = { implicit val c = column.uencoder sparkFunctions.max(column.untyped).typedAggregate } - /** Aggregate function: returns the minimum value of the column in a group. - * - * apache/spark - */ - def min[A: CatalystOrdered, T](column: TypedColumn[T, A]): TypedAggregate[T, A] = { + /** + * Aggregate function: returns the minimum value of the column in a group. + * + * apache/spark + */ + def min[A: CatalystOrdered, T]( + column: TypedColumn[T, A] + ): TypedAggregate[T, A] = { implicit val c = column.uencoder sparkFunctions.min(column.untyped).typedAggregate } - /** Aggregate function: returns the first value in a group. - * - * The function by default returns the first values it sees. It will return the first non-null - * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. - * - * apache/spark - */ + /** + * Aggregate function: returns the first value in a group. + * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * apache/spark + */ def first[A, T](column: TypedColumn[T, A]): TypedAggregate[T, A] = { sparkFunctions.first(column.untyped).typedAggregate(column.uencoder) } /** - * Aggregate function: returns the last value in a group. - * - * The function by default returns the last values it sees. It will return the last non-null - * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. - * - * apache/spark - */ + * Aggregate function: returns the last value in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * apache/spark + */ def last[A, T](column: TypedColumn[T, A]): TypedAggregate[T, A] = { implicit val c = column.uencoder sparkFunctions.last(column.untyped).typedAggregate } /** - * Aggregate function: returns the Pearson Correlation Coefficient for two columns. - * - * @note In Spark corr always returns Double - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala#L95]] - * - * apache/spark - */ - def corr[A, B, T](column1: TypedColumn[T, A], column2: TypedColumn[T, B]) - (implicit + * Aggregate function: returns the Pearson Correlation Coefficient for two columns. 
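These aggregators (avg, stddev, corr, ...) are consumed through TypedDataset#agg. A brief sketch of the pattern, assuming the usual frameless imports; the Sale case class and its columns are illustrative, not part of the patch:

    import frameless.TypedDataset
    import frameless.functions.aggregate._ // avg, stddev, corr, ...

    final case class Sale(price: Double, quantity: Double)

    // Each aggregate keeps its precise result type: avg and stddev of a Double
    // column stay Double, while corr is Option[Double] because Spark may
    // return null for degenerate input.
    def stats(sales: TypedDataset[Sale]): TypedDataset[(Double, Double, Option[Double])] =
      sales.agg(
        avg(sales('price)),
        stddev(sales('price)),
        corr(sales('price), sales('quantity))
      )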
+ * + * @note In Spark corr always returns Double + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala#L95]] + * + * apache/spark + */ + def corr[A, B, T]( + column1: TypedColumn[T, A], + column2: TypedColumn[T, B] + )(implicit i0: CatalystCast[A, Double], i1: CatalystCast[B, Double] ): TypedAggregate[T, Option[Double]] = { - new TypedAggregate[T, Option[Double]]( - sparkFunctions.corr(column1.cast[Double].untyped, column2.cast[Double].untyped) - ) - } + new TypedAggregate[T, Option[Double]]( + sparkFunctions + .corr(column1.cast[Double].untyped, column2.cast[Double].untyped) + ) + } /** - * Aggregate function: returns the covariance of two collumns. - * - * @note In Spark covar_pop always returns Double - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala#L82]] - * - * apache/spark - */ - def covarPop[A, B, T](column1: TypedColumn[T, A], column2: TypedColumn[T, B]) - (implicit + * Aggregate function: returns the covariance of two collumns. + * + * @note In Spark covar_pop always returns Double + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala#L82]] + * + * apache/spark + */ + def covarPop[A, B, T]( + column1: TypedColumn[T, A], + column2: TypedColumn[T, B] + )(implicit i0: CatalystCast[A, Double], i1: CatalystCast[B, Double] ): TypedAggregate[T, Option[Double]] = { - new TypedAggregate[T, Option[Double]]( - sparkFunctions.covar_pop(column1.cast[Double].untyped, column2.cast[Double].untyped) - ) - } + new TypedAggregate[T, Option[Double]]( + sparkFunctions + .covar_pop(column1.cast[Double].untyped, column2.cast[Double].untyped) + ) + } /** - * Aggregate function: returns the covariance of two columns. - * - * @note In Spark covar_samp always returns Double - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala#L93]] - * - * apache/spark - */ - def covarSamp[A, B, T](column1: TypedColumn[T, A], column2: TypedColumn[T, B]) - (implicit + * Aggregate function: returns the covariance of two columns. + * + * @note In Spark covar_samp always returns Double + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala#L93]] + * + * apache/spark + */ + def covarSamp[A, B, T]( + column1: TypedColumn[T, A], + column2: TypedColumn[T, B] + )(implicit i0: CatalystCast[A, Double], i1: CatalystCast[B, Double] ): TypedAggregate[T, Option[Double]] = { - new TypedAggregate[T, Option[Double]]( - sparkFunctions.covar_samp(column1.cast[Double].untyped, column2.cast[Double].untyped) - ) - } - + new TypedAggregate[T, Option[Double]]( + sparkFunctions + .covar_samp(column1.cast[Double].untyped, column2.cast[Double].untyped) + ) + } /** - * Aggregate function: returns the kurtosis of a column. 
- * - * @note In Spark kurtosis always returns Double - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L220]] - * - * apache/spark - */ - def kurtosis[A, T](column: TypedColumn[T, A])(implicit ev: CatalystCast[A, Double]): TypedAggregate[T, Option[Double]] = { + * Aggregate function: returns the kurtosis of a column. + * + * @note In Spark kurtosis always returns Double + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L220]] + * + * apache/spark + */ + def kurtosis[A, T]( + column: TypedColumn[T, A] + )(implicit + ev: CatalystCast[A, Double] + ): TypedAggregate[T, Option[Double]] = { new TypedAggregate[T, Option[Double]]( sparkFunctions.kurtosis(column.cast[Double].untyped) ) } /** - * Aggregate function: returns the skewness of a column. - * - * @note In Spark skewness always returns Double - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L200]] - * - * apache/spark - */ - def skewness[A, T](column: TypedColumn[T, A])(implicit ev: CatalystCast[A, Double]): TypedAggregate[T, Option[Double]] = { + * Aggregate function: returns the skewness of a column. + * + * @note In Spark skewness always returns Double + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala#L200]] + * + * apache/spark + */ + def skewness[A, T]( + column: TypedColumn[T, A] + )(implicit + ev: CatalystCast[A, Double] + ): TypedAggregate[T, Option[Double]] = { new TypedAggregate[T, Option[Double]]( sparkFunctions.skewness(column.cast[Double].untyped) ) diff --git a/dataset/src/main/scala/frameless/functions/Lit.scala b/dataset/src/main/scala/frameless/functions/Lit.scala index d01467b13..78ee2fd3b 100644 --- a/dataset/src/main/scala/frameless/functions/Lit.scala +++ b/dataset/src/main/scala/frameless/functions/Lit.scala @@ -2,7 +2,10 @@ package frameless.functions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression} +import org.apache.spark.sql.catalyst.expressions.{ + Expression, + NonSQLExpression +} import org.apache.spark.sql.types.DataType private[frameless] case class Lit[T <: AnyVal]( @@ -10,7 +13,8 @@ private[frameless] case class Lit[T <: AnyVal]( nullable: Boolean, show: () => String, catalystExpr: Expression // must be a generated Expression from a literal TypedEncoder's toCatalyst function -) extends Expression with NonSQLExpression { + ) extends Expression + with NonSQLExpression { override def toString: String = s"FramelessLit(${show()})" lazy val codegen = { @@ -52,12 +56,15 @@ private[frameless] case class Lit[T <: AnyVal]( } def eval(input: InternalRow): Any = codegen(input) - + def children: Seq[Expression] = Nil - protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = catalystExpr.genCode(ctx) + protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + catalystExpr.genCode(ctx) - protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = this + protected def 
withNewChildrenInternal( + newChildren: IndexedSeq[Expression] + ): Expression = this override val foldable: Boolean = catalystExpr.foldable } diff --git a/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala b/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala index 939bf5b8d..4ce3c63cf 100644 --- a/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala +++ b/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala @@ -1,537 +1,729 @@ package frameless package functions -import org.apache.spark.sql.{Column, functions => sparkFunctions} +import org.apache.spark.sql.{ Column, functions => sparkFunctions } import scala.annotation.nowarn import scala.util.matching.Regex trait NonAggregateFunctions { - /** Non-Aggregate function: calculates the SHA-2 digest of a binary column and returns the value as a 40 character hex string - * - * apache/spark - */ - def sha2[T](column: AbstractTypedColumn[T, Array[Byte]], numBits: Int): column.ThisType[T, String] = + + /** + * Non-Aggregate function: calculates the SHA-2 digest of a binary column and returns the value as a 40 character hex string + * + * apache/spark + */ + def sha2[T]( + column: AbstractTypedColumn[T, Array[Byte]], + numBits: Int + ): column.ThisType[T, String] = column.typed(sparkFunctions.sha2(column.untyped, numBits)) - /** Non-Aggregate function: calculates the SHA-1 digest of a binary column and returns the value as a 40 character hex string - * - * apache/spark - */ - def sha1[T](column: AbstractTypedColumn[T, Array[Byte]]): column.ThisType[T, String] = + /** + * Non-Aggregate function: calculates the SHA-1 digest of a binary column and returns the value as a 40 character hex string + * + * apache/spark + */ + def sha1[T]( + column: AbstractTypedColumn[T, Array[Byte]] + ): column.ThisType[T, String] = column.typed(sparkFunctions.sha1(column.untyped)) - /** Non-Aggregate function: returns a cyclic redundancy check value of a binary column as long. - * - * apache/spark - */ - def crc32[T](column: AbstractTypedColumn[T, Array[Byte]]): column.ThisType[T, Long] = + /** + * Non-Aggregate function: returns a cyclic redundancy check value of a binary column as long. + * + * apache/spark + */ + def crc32[T]( + column: AbstractTypedColumn[T, Array[Byte]] + ): column.ThisType[T, Long] = column.typed(sparkFunctions.crc32(column.untyped)) + /** - * Non-Aggregate function: returns the negated value of column. - * - * apache/spark - */ - def negate[A, B, T](column: AbstractTypedColumn[T,A])( - implicit i0: CatalystNumericWithJavaBigDecimal[A, B], - i1: TypedEncoder[B] - ): column.ThisType[T,B] = + * Non-Aggregate function: returns the negated value of column. + * + * apache/spark + */ + def negate[A, B, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystNumericWithJavaBigDecimal[A, B], + i1: TypedEncoder[B] + ): column.ThisType[T, B] = column.typed(sparkFunctions.negate(column.untyped)) /** - * Non-Aggregate function: logical not. - * - * apache/spark - */ - def not[T](column: AbstractTypedColumn[T,Boolean]): column.ThisType[T,Boolean] = + * Non-Aggregate function: logical not. + * + * apache/spark + */ + def not[T]( + column: AbstractTypedColumn[T, Boolean] + ): column.ThisType[T, Boolean] = column.typed(sparkFunctions.not(column.untyped)) /** - * Non-Aggregate function: Convert a number in a string column from one base to another. 
- * - * apache/spark - */ - def conv[T](column: AbstractTypedColumn[T,String], fromBase: Int, toBase: Int): column.ThisType[T,String] = - column.typed(sparkFunctions.conv(column.untyped,fromBase,toBase)) + * Non-Aggregate function: Convert a number in a string column from one base to another. + * + * apache/spark + */ + def conv[T]( + column: AbstractTypedColumn[T, String], + fromBase: Int, + toBase: Int + ): column.ThisType[T, String] = + column.typed(sparkFunctions.conv(column.untyped, fromBase, toBase)) - /** Non-Aggregate function: Converts an angle measured in radians to an approximately equivalent angle measured in degrees. - * - * apache/spark - */ - def degrees[A,T](column: AbstractTypedColumn[T,A]): column.ThisType[T,Double] = + /** + * Non-Aggregate function: Converts an angle measured in radians to an approximately equivalent angle measured in degrees. + * + * apache/spark + */ + def degrees[A, T]( + column: AbstractTypedColumn[T, A] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.degrees(column.untyped)) - /** Non-Aggregate function: returns the ceiling of a numeric column - * - * apache/spark - */ - def ceil[A, B, T](column: AbstractTypedColumn[T, A]) - (implicit + /** + * Non-Aggregate function: returns the ceiling of a numeric column + * + * apache/spark + */ + def ceil[A, B, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystRound[A, B], + i1: TypedEncoder[B] + ): column.ThisType[T, B] = + column.typed(sparkFunctions.ceil(column.untyped))(i1) + + /** + * Non-Aggregate function: returns the floor of a numeric column + * + * apache/spark + */ + def floor[A, B, T]( + column: AbstractTypedColumn[T, A] + )(implicit i0: CatalystRound[A, B], i1: TypedEncoder[B] ): column.ThisType[T, B] = - column.typed(sparkFunctions.ceil(column.untyped))(i1) - - /** Non-Aggregate function: returns the floor of a numeric column - * - * apache/spark - */ - def floor[A, B, T](column: AbstractTypedColumn[T, A]) - (implicit - i0: CatalystRound[A, B], - i1: TypedEncoder[B] - ): column.ThisType[T, B] = column.typed(sparkFunctions.floor(column.untyped))(i1) - /** Non-Aggregate function: unsigned shift the the given value numBits right. If given long, will return long else it will return an integer. - * - * apache/spark - */ + /** + * Non-Aggregate function: unsigned shift the the given value numBits right. If given long, will return long else it will return an integer. + * + * apache/spark + */ @nowarn // supress sparkFunctions.shiftRightUnsigned call which is used to maintain Spark 3.1.x backwards compat - def shiftRightUnsigned[A, B, T](column: AbstractTypedColumn[T, A], numBits: Int) - (implicit + def shiftRightUnsigned[A, B, T]( + column: AbstractTypedColumn[T, A], + numBits: Int + )(implicit i0: CatalystBitShift[A, B], i1: TypedEncoder[B] ): column.ThisType[T, B] = - column.typed(sparkFunctions.shiftRightUnsigned(column.untyped, numBits)) + column.typed(sparkFunctions.shiftRightUnsigned(column.untyped, numBits)) - /** Non-Aggregate function: shift the the given value numBits right. If given long, will return long else it will return an integer. - * - * apache/spark - */ + /** + * Non-Aggregate function: shift the the given value numBits right. If given long, will return long else it will return an integer. 
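The numeric helpers in this file are applied column-wise inside select. A short sketch using two of the functions whose signatures appear in this hunk (degrees and conv); the Measurement case class is an assumption for the example:

    import frameless.TypedDataset
    import frameless.functions.nonAggregate._ // degrees, conv, ...

    // `code` is assumed to hold a base-10 numeric string; `reading` is in radians.
    final case class Measurement(code: String, reading: Double)

    // degrees always yields Double; conv re-encodes a numeric string between bases.
    def transformed(ds: TypedDataset[Measurement]): TypedDataset[(Double, String)] =
      ds.select(
        degrees(ds('reading)),
        conv(ds('code), 10, 2)
      )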
+ * + * apache/spark + */ @nowarn // supress sparkFunctions.shiftReft call which is used to maintain Spark 3.1.x backwards compat - def shiftRight[A, B, T](column: AbstractTypedColumn[T, A], numBits: Int) - (implicit + def shiftRight[A, B, T]( + column: AbstractTypedColumn[T, A], + numBits: Int + )(implicit i0: CatalystBitShift[A, B], i1: TypedEncoder[B] ): column.ThisType[T, B] = - column.typed(sparkFunctions.shiftRight(column.untyped, numBits)) + column.typed(sparkFunctions.shiftRight(column.untyped, numBits)) - /** Non-Aggregate function: shift the the given value numBits left. If given long, will return long else it will return an integer. - * - * apache/spark - */ + /** + * Non-Aggregate function: shift the the given value numBits left. If given long, will return long else it will return an integer. + * + * apache/spark + */ @nowarn // supress sparkFunctions.shiftLeft call which is used to maintain Spark 3.1.x backwards compat - def shiftLeft[A, B, T](column: AbstractTypedColumn[T, A], numBits: Int) - (implicit + def shiftLeft[A, B, T]( + column: AbstractTypedColumn[T, A], + numBits: Int + )(implicit i0: CatalystBitShift[A, B], i1: TypedEncoder[B] ): column.ThisType[T, B] = column.typed(sparkFunctions.shiftLeft(column.untyped, numBits)) - - /** Non-Aggregate function: returns the absolute value of a numeric column - * - * apache/spark - */ - def abs[A, B, T](column: AbstractTypedColumn[T, A]) - (implicit - i0: CatalystNumericWithJavaBigDecimal[A, B], - i1: TypedEncoder[B] + + /** + * Non-Aggregate function: returns the absolute value of a numeric column + * + * apache/spark + */ + def abs[A, B, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystNumericWithJavaBigDecimal[A, B], + i1: TypedEncoder[B] ): column.ThisType[T, B] = - column.typed(sparkFunctions.abs(column.untyped))(i1) - - /** Non-Aggregate function: Computes the cosine of the given value. - * - * Spark will expect a Double value for this expression. See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def cos[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = - column.typed(sparkFunctions.cos(column.cast[Double].untyped)) - - /** Non-Aggregate function: Computes the hyperbolic cosine of the given value. - * - * Spark will expect a Double value for this expression. See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def cosh[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = - column.typed(sparkFunctions.cosh(column.cast[Double].untyped)) - - /** Non-Aggregate function: Computes the signum of the given value. - * - * Spark will expect a Double value for this expression. See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def signum[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = + column.typed(sparkFunctions.abs(column.untyped))(i1) + + /** + * Non-Aggregate function: Computes the cosine of the given value. 
+ * + * Spark will expect a Double value for this expression. See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def cos[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = + column.typed(sparkFunctions.cos(column.cast[Double].untyped)) + + /** + * Non-Aggregate function: Computes the hyperbolic cosine of the given value. + * + * Spark will expect a Double value for this expression. See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def cosh[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = + column.typed(sparkFunctions.cosh(column.cast[Double].untyped)) + + /** + * Non-Aggregate function: Computes the signum of the given value. + * + * Spark will expect a Double value for this expression. See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def signum[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.signum(column.cast[Double].untyped)) - /** Non-Aggregate function: Computes the sine of the given value. - * - * Spark will expect a Double value for this expression. See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def sin[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = - column.typed(sparkFunctions.sin(column.cast[Double].untyped)) - - /** Non-Aggregate function: Computes the hyperbolic sine of the given value. - * - * Spark will expect a Double value for this expression. See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def sinh[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = - column.typed(sparkFunctions.sinh(column.cast[Double].untyped)) - - /** Non-Aggregate function: Computes the tangent of the given column. - * - * Spark will expect a Double value for this expression. See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def tan[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = - column.typed(sparkFunctions.tan(column.cast[Double].untyped)) - - /** Non-Aggregate function: Computes the hyperbolic tangent of the given value. - * - * Spark will expect a Double value for this expression. 
See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def tanh[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = - column.typed(sparkFunctions.tanh(column.cast[Double].untyped)) - - /** Non-Aggregate function: returns the acos of a numeric column - * - * Spark will expect a Double value for this expression. See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def acos[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = - column.typed(sparkFunctions.acos(column.cast[Double].untyped)) - - /** Non-Aggregate function: returns true if value is contained with in the array in the specified column - * - * apache/spark - */ - def arrayContains[C[_]: CatalystCollection, A, T](column: AbstractTypedColumn[T, C[A]], value: A): column.ThisType[T, Boolean] = + /** + * Non-Aggregate function: Computes the sine of the given value. + * + * Spark will expect a Double value for this expression. See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def sin[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = + column.typed(sparkFunctions.sin(column.cast[Double].untyped)) + + /** + * Non-Aggregate function: Computes the hyperbolic sine of the given value. + * + * Spark will expect a Double value for this expression. See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def sinh[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = + column.typed(sparkFunctions.sinh(column.cast[Double].untyped)) + + /** + * Non-Aggregate function: Computes the tangent of the given column. + * + * Spark will expect a Double value for this expression. See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def tan[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = + column.typed(sparkFunctions.tan(column.cast[Double].untyped)) + + /** + * Non-Aggregate function: Computes the hyperbolic tangent of the given value. + * + * Spark will expect a Double value for this expression. See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def tanh[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = + column.typed(sparkFunctions.tanh(column.cast[Double].untyped)) + + /** + * Non-Aggregate function: returns the acos of a numeric column + * + * Spark will expect a Double value for this expression. 
See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def acos[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = + column.typed(sparkFunctions.acos(column.cast[Double].untyped)) + + /** + * Non-Aggregate function: returns true if value is contained with in the array in the specified column + * + * apache/spark + */ + def arrayContains[C[_]: CatalystCollection, A, T]( + column: AbstractTypedColumn[T, C[A]], + value: A + ): column.ThisType[T, Boolean] = column.typed(sparkFunctions.array_contains(column.untyped, value)) - /** Non-Aggregate function: returns the atan of a numeric column - * - * Spark will expect a Double value for this expression. See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def atan[A, T](column: AbstractTypedColumn[T,A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = - column.typed(sparkFunctions.atan(column.cast[Double].untyped)) - - /** Non-Aggregate function: returns the asin of a numeric column - * - * Spark will expect a Double value for this expression. See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def asin[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = - column.typed(sparkFunctions.asin(column.cast[Double].untyped)) - - /** Non-Aggregate function: returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). - * - * Spark will expect a Double value for this expression. See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def atan2[A, B, T](l: TypedColumn[T, A], r: TypedColumn[T, B]) - (implicit + /** + * Non-Aggregate function: returns the atan of a numeric column + * + * Spark will expect a Double value for this expression. See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def atan[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = + column.typed(sparkFunctions.atan(column.cast[Double].untyped)) + + /** + * Non-Aggregate function: returns the asin of a numeric column + * + * Spark will expect a Double value for this expression. See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def asin[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = + column.typed(sparkFunctions.asin(column.cast[Double].untyped)) + + /** + * Non-Aggregate function: returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). 
+ * + * Spark will expect a Double value for this expression. See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def atan2[A, B, T]( + l: TypedColumn[T, A], + r: TypedColumn[T, B] + )(implicit i0: CatalystCast[A, Double], i1: CatalystCast[B, Double] ): TypedColumn[T, Double] = - r.typed(sparkFunctions.atan2(l.cast[Double].untyped, r.cast[Double].untyped)) - - /** Non-Aggregate function: returns the angle theta from the conversion of rectangular coordinates (x, y) to - * polar coordinates (r, theta). - * - * Spark will expect a Double value for this expression. See: - * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] - * apache/spark - */ - def atan2[A, B, T](l: TypedAggregate[T, A], r: TypedAggregate[T, B]) - (implicit + r.typed( + sparkFunctions.atan2(l.cast[Double].untyped, r.cast[Double].untyped) + ) + + /** + * Non-Aggregate function: returns the angle theta from the conversion of rectangular coordinates (x, y) to + * polar coordinates (r, theta). + * + * Spark will expect a Double value for this expression. See: + * [[https://github.com/apache/spark/blob/4a3c09601ba69f7d49d1946bb6f20f5cfe453031/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala#L67]] + * apache/spark + */ + def atan2[A, B, T]( + l: TypedAggregate[T, A], + r: TypedAggregate[T, B] + )(implicit i0: CatalystCast[A, Double], i1: CatalystCast[B, Double] ): TypedAggregate[T, Double] = - r.typed(sparkFunctions.atan2(l.cast[Double].untyped, r.cast[Double].untyped)) - - def atan2[B, T](l: Double, r: TypedColumn[T, B]) - (implicit i0: CatalystCast[B, Double]): TypedColumn[T, Double] = - atan2(r.lit(l), r) - - def atan2[A, T](l: TypedColumn[T, A], r: Double) - (implicit i0: CatalystCast[A, Double]): TypedColumn[T, Double] = - atan2(l, l.lit(r)) - - def atan2[B, T](l: Double, r: TypedAggregate[T, B]) - (implicit i0: CatalystCast[B, Double]): TypedAggregate[T, Double] = - atan2(r.lit(l), r) - - def atan2[A, T](l: TypedAggregate[T, A], r: Double) - (implicit i0: CatalystCast[A, Double]): TypedAggregate[T, Double] = - atan2(l, l.lit(r)) - - /** Non-Aggregate function: returns the square root value of a numeric column. - * - * apache/spark - */ - def sqrt[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = + r.typed( + sparkFunctions.atan2(l.cast[Double].untyped, r.cast[Double].untyped) + ) + + def atan2[B, T]( + l: Double, + r: TypedColumn[T, B] + )(implicit + i0: CatalystCast[B, Double] + ): TypedColumn[T, Double] = + atan2(r.lit(l), r) + + def atan2[A, T]( + l: TypedColumn[T, A], + r: Double + )(implicit + i0: CatalystCast[A, Double] + ): TypedColumn[T, Double] = + atan2(l, l.lit(r)) + + def atan2[B, T]( + l: Double, + r: TypedAggregate[T, B] + )(implicit + i0: CatalystCast[B, Double] + ): TypedAggregate[T, Double] = + atan2(r.lit(l), r) + + def atan2[A, T]( + l: TypedAggregate[T, A], + r: Double + )(implicit + i0: CatalystCast[A, Double] + ): TypedAggregate[T, Double] = + atan2(l, l.lit(r)) + + /** + * Non-Aggregate function: returns the square root value of a numeric column. 
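Like the trigonometric helpers above, these functions cast their input to Double before delegating to Spark. A small sketch combining atan2 and sqrt; the Point case class is an assumed example type:

    import frameless.TypedDataset
    import frameless.functions.nonAggregate._ // atan2, sqrt, ...

    final case class Point(x: Double, y: Double)

    // Convert rectangular coordinates to the polar angle and radius.
    def polar(points: TypedDataset[Point]): TypedDataset[(Double, Double)] =
      points.select(
        atan2(points('y), points('x)),
        sqrt(points('x) * points('x) + points('y) * points('y))
      )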
+ * + * apache/spark + */ + def sqrt[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.sqrt(column.cast[Double].untyped)) - /** Non-Aggregate function: returns the cubic root value of a numeric column. - * - * apache/spark - */ - def cbrt[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = + /** + * Non-Aggregate function: returns the cubic root value of a numeric column. + * + * apache/spark + */ + def cbrt[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.cbrt(column.cast[Double].untyped)) - /** Non-Aggregate function: returns the exponential value of a numeric column. - * - * apache/spark - */ - def exp[A, T](column: AbstractTypedColumn[T, A]) - (implicit i0: CatalystCast[A, Double]): column.ThisType[T, Double] = + /** + * Non-Aggregate function: returns the exponential value of a numeric column. + * + * apache/spark + */ + def exp[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.exp(column.cast[Double].untyped)) - /** Non-Aggregate function: Returns the value of the column `e` rounded to 0 decimal places with HALF_UP round mode. - * - * apache/spark - */ - def round[A, B, T](column: AbstractTypedColumn[T, A])( - implicit i0: CatalystNumericWithJavaBigDecimal[A, B], i1: TypedEncoder[B] - ): column.ThisType[T, B] = + /** + * Non-Aggregate function: Returns the value of the column `e` rounded to 0 decimal places with HALF_UP round mode. + * + * apache/spark + */ + def round[A, B, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystNumericWithJavaBigDecimal[A, B], + i1: TypedEncoder[B] + ): column.ThisType[T, B] = column.typed(sparkFunctions.round(column.untyped))(i1) - /** Non-Aggregate function: Round the value of `e` to `scale` decimal places with HALF_UP round mode - * if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0. - * - * apache/spark - */ - def round[A, B, T](column: AbstractTypedColumn[T, A], scale: Int)( - implicit i0: CatalystNumericWithJavaBigDecimal[A, B], i1: TypedEncoder[B] - ): column.ThisType[T, B] = + /** + * Non-Aggregate function: Round the value of `e` to `scale` decimal places with HALF_UP round mode + * if `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0. + * + * apache/spark + */ + def round[A, B, T]( + column: AbstractTypedColumn[T, A], + scale: Int + )(implicit + i0: CatalystNumericWithJavaBigDecimal[A, B], + i1: TypedEncoder[B] + ): column.ThisType[T, B] = column.typed(sparkFunctions.round(column.untyped, scale))(i1) - /** Non-Aggregate function: Bankers Rounding - returns the rounded to 0 decimal places value with HALF_EVEN round mode - * of a numeric column. - * - * apache/spark - */ - def bround[A, B, T](column: AbstractTypedColumn[T, A])( - implicit i0: CatalystNumericWithJavaBigDecimal[A, B], i1: TypedEncoder[B] - ): column.ThisType[T, B] = + /** + * Non-Aggregate function: Bankers Rounding - returns the rounded to 0 decimal places value with HALF_EVEN round mode + * of a numeric column. 
+ * + * apache/spark + */ + def bround[A, B, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystNumericWithJavaBigDecimal[A, B], + i1: TypedEncoder[B] + ): column.ThisType[T, B] = column.typed(sparkFunctions.bround(column.untyped))(i1) - /** Non-Aggregate function: Bankers Rounding - returns the rounded to `scale` decimal places value with HALF_EVEN round mode - * of a numeric column. If `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0. - * - * apache/spark - */ - def bround[A, B, T](column: AbstractTypedColumn[T, A], scale: Int)( - implicit i0: CatalystNumericWithJavaBigDecimal[A, B], i1: TypedEncoder[B] - ): column.ThisType[T, B] = + /** + * Non-Aggregate function: Bankers Rounding - returns the rounded to `scale` decimal places value with HALF_EVEN round mode + * of a numeric column. If `scale` is greater than or equal to 0 or at integral part when `scale` is less than 0. + * + * apache/spark + */ + def bround[A, B, T]( + column: AbstractTypedColumn[T, A], + scale: Int + )(implicit + i0: CatalystNumericWithJavaBigDecimal[A, B], + i1: TypedEncoder[B] + ): column.ThisType[T, B] = column.typed(sparkFunctions.bround(column.untyped, scale))(i1) /** - * Computes the natural logarithm of the given value. - * - * apache/spark - */ - def log[A, T](column: AbstractTypedColumn[T, A])( - implicit i0: CatalystCast[A, Double] - ): column.ThisType[T, Double] = + * Computes the natural logarithm of the given value. + * + * apache/spark + */ + def log[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.log(column.untyped)) /** - * Returns the first argument-base logarithm of the second argument. - * - * apache/spark - */ - def log[A, T](base: Double, column: AbstractTypedColumn[T, A])( - implicit i0: CatalystCast[A, Double] - ): column.ThisType[T, Double] = + * Returns the first argument-base logarithm of the second argument. + * + * apache/spark + */ + def log[A, T]( + base: Double, + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.log(base, column.untyped)) /** - * Computes the logarithm of the given column in base 2. - * - * apache/spark - */ - def log2[A, T](column: AbstractTypedColumn[T, A])( - implicit i0: CatalystCast[A, Double] - ): column.ThisType[T, Double] = + * Computes the logarithm of the given column in base 2. + * + * apache/spark + */ + def log2[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.log2(column.untyped)) /** - * Computes the natural logarithm of the given value plus one. - * - * apache/spark - */ - def log1p[A, T](column: AbstractTypedColumn[T, A])( - implicit i0: CatalystCast[A, Double] - ): column.ThisType[T, Double] = + * Computes the natural logarithm of the given value plus one. + * + * apache/spark + */ + def log1p[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.log1p(column.untyped)) /** - * Computes the logarithm of the given column in base 10. - * - * apache/spark - */ - def log10[A, T](column: AbstractTypedColumn[T, A])( - implicit i0: CatalystCast[A, Double] - ): column.ThisType[T, Double] = + * Computes the logarithm of the given column in base 10. 
+ * + * apache/spark + */ + def log10[A, T]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.log10(column.untyped)) - /** - * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. - * - * apache/spark - */ - def hypot[A, T](column: AbstractTypedColumn[T, A], column2: AbstractTypedColumn[T, A])( - implicit i0: CatalystCast[A, Double] - ): column.ThisType[T, Double] = + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * apache/spark + */ + def hypot[A, T]( + column: AbstractTypedColumn[T, A], + column2: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.hypot(column.untyped, column2.untyped)) /** - * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. - * - * apache/spark - */ - def hypot[A, T](column: AbstractTypedColumn[T, A], l: Double)( - implicit i0: CatalystCast[A, Double] - ): column.ThisType[T, Double] = + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * apache/spark + */ + def hypot[A, T]( + column: AbstractTypedColumn[T, A], + l: Double + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.hypot(column.untyped, l)) /** - * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. - * - * apache/spark - */ - def hypot[A, T](l: Double, column: AbstractTypedColumn[T, A])( - implicit i0: CatalystCast[A, Double] - ): column.ThisType[T, Double] = + * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. + * + * apache/spark + */ + def hypot[A, T]( + l: Double, + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.hypot(l, column.untyped)) /** - * Returns the value of the first argument raised to the power of the second argument. - * - * apache/spark - */ - def pow[A, T](column: AbstractTypedColumn[T, A], column2: AbstractTypedColumn[T, A])( - implicit i0: CatalystCast[A, Double] - ): column.ThisType[T, Double] = + * Returns the value of the first argument raised to the power of the second argument. + * + * apache/spark + */ + def pow[A, T]( + column: AbstractTypedColumn[T, A], + column2: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.pow(column.untyped, column2.untyped)) /** - * Returns the value of the first argument raised to the power of the second argument. - * - * apache/spark - */ - def pow[A, T](column: AbstractTypedColumn[T, A], l: Double)( - implicit i0: CatalystCast[A, Double] - ): column.ThisType[T, Double] = + * Returns the value of the first argument raised to the power of the second argument. + * + * apache/spark + */ + def pow[A, T]( + column: AbstractTypedColumn[T, A], + l: Double + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.pow(column.untyped, l)) /** - * Returns the value of the first argument raised to the power of the second argument. - * - * apache/spark - */ - def pow[A, T](l: Double, column: AbstractTypedColumn[T, A])( - implicit i0: CatalystCast[A, Double] - ): column.ThisType[T, Double] = + * Returns the value of the first argument raised to the power of the second argument. 
+ * + * apache/spark + */ + def pow[A, T]( + l: Double, + column: AbstractTypedColumn[T, A] + )(implicit + i0: CatalystCast[A, Double] + ): column.ThisType[T, Double] = column.typed(sparkFunctions.pow(l, column.untyped)) /** - * Returns the positive value of dividend mod divisor. - * - * apache/spark - */ - def pmod[A, T](column: AbstractTypedColumn[T, A], column2: AbstractTypedColumn[T, A])( - implicit i0: TypedEncoder[A] - ): column.ThisType[T, A] = + * Returns the positive value of dividend mod divisor. + * + * apache/spark + */ + def pmod[A, T]( + column: AbstractTypedColumn[T, A], + column2: AbstractTypedColumn[T, A] + )(implicit + i0: TypedEncoder[A] + ): column.ThisType[T, A] = column.typed(sparkFunctions.pmod(column.untyped, column2.untyped)) - - /** Non-Aggregate function: Returns the string representation of the binary value of the given long - * column. For example, bin("12") returns "1100". - * - * apache/spark - */ + /** + * Non-Aggregate function: Returns the string representation of the binary value of the given long + * column. For example, bin("12") returns "1100". + * + * apache/spark + */ def bin[T](column: AbstractTypedColumn[T, Long]): column.ThisType[T, String] = column.typed(sparkFunctions.bin(column.untyped)) /** - * Calculates the MD5 digest of a binary column and returns the value - * as a 32 character hex string. - * - * apache/spark - */ - def md5[T, A](column: AbstractTypedColumn[T, A])(implicit i0: TypedEncoder[A]): column.ThisType[T, String] = + * Calculates the MD5 digest of a binary column and returns the value + * as a 32 character hex string. + * + * apache/spark + */ + def md5[T, A]( + column: AbstractTypedColumn[T, A] + )(implicit + i0: TypedEncoder[A] + ): column.ThisType[T, String] = column.typed(sparkFunctions.md5(column.untyped)) /** - * Computes the factorial of the given value. - * - * apache/spark - */ - def factorial[T](column: AbstractTypedColumn[T, Long])(implicit i0: TypedEncoder[Long]): column.ThisType[T, Long] = + * Computes the factorial of the given value. + * + * apache/spark + */ + def factorial[T]( + column: AbstractTypedColumn[T, Long] + )(implicit + i0: TypedEncoder[Long] + ): column.ThisType[T, Long] = column.typed(sparkFunctions.factorial(column.untyped)) - /** Non-Aggregate function: Computes bitwise NOT. - * - * apache/spark - */ + /** + * Non-Aggregate function: Computes bitwise NOT. + * + * apache/spark + */ @nowarn // supress sparkFunctions.bitwiseNOT call which is used to maintain Spark 3.1.x backwards compat - def bitwiseNOT[A: CatalystBitwise, T](column: AbstractTypedColumn[T, A]): column.ThisType[T, A] = + def bitwiseNOT[A: CatalystBitwise, T]( + column: AbstractTypedColumn[T, A] + ): column.ThisType[T, A] = column.typed(sparkFunctions.bitwiseNOT(column.untyped))(column.uencoder) - /** Non-Aggregate function: file name of the current Spark task. Empty string if row did not originate from - * a file - * - * apache/spark - */ + /** + * Non-Aggregate function: file name of the current Spark task. 
Empty string if row did not originate from + * a file + * + * apache/spark + */ def inputFileName[T](): TypedColumn[T, String] = new TypedColumn[T, String](sparkFunctions.input_file_name()) - /** Non-Aggregate function: generates monotonically increasing id - * - * apache/spark - */ + /** + * Non-Aggregate function: generates monotonically increasing id + * + * apache/spark + */ def monotonicallyIncreasingId[T](): TypedColumn[T, Long] = { new TypedColumn[T, Long](sparkFunctions.monotonically_increasing_id()) } - /** Non-Aggregate function: Evaluates a list of conditions and returns one of multiple - * possible result expressions. If none match, otherwise is returned - * {{{ - * when(ds('boolField), ds('a)) - * .when(ds('otherBoolField), lit(123)) - * .otherwise(ds('b)) - * }}} - * apache/spark - */ - def when[T, A](condition: AbstractTypedColumn[T, Boolean], value: AbstractTypedColumn[T, A]): When[T, A] = + /** + * Non-Aggregate function: Evaluates a list of conditions and returns one of multiple + * possible result expressions. If none match, otherwise is returned + * {{{ + * when(ds('boolField), ds('a)) + * .when(ds('otherBoolField), lit(123)) + * .otherwise(ds('b)) + * }}} + * apache/spark + */ + def when[T, A]( + condition: AbstractTypedColumn[T, Boolean], + value: AbstractTypedColumn[T, A] + ): When[T, A] = new When[T, A](condition, value) class When[T, A] private (untypedC: Column) { - private[functions] def this(condition: AbstractTypedColumn[T, Boolean], value: AbstractTypedColumn[T, A]) = + private[functions] def this( + condition: AbstractTypedColumn[T, Boolean], + value: AbstractTypedColumn[T, A] + ) = this(sparkFunctions.when(condition.untyped, value.untyped)) - def when(condition: AbstractTypedColumn[T, Boolean], value: AbstractTypedColumn[T, A]): When[T, A] = + def when( + condition: AbstractTypedColumn[T, Boolean], + value: AbstractTypedColumn[T, A] + ): When[T, A] = new When[T, A](untypedC.when(condition.untyped, value.untyped)) def otherwise(value: AbstractTypedColumn[T, A]): value.ThisType[T, A] = @@ -542,172 +734,230 @@ trait NonAggregateFunctions { // String functions ////////////////////////////////////////////////////////////////////////////////////////////// - - /** Non-Aggregate function: takes the first letter of a string column and returns the ascii int value in a new column - * - * apache/spark - */ - def ascii[T](column: AbstractTypedColumn[T, String]): column.ThisType[T, Int] = + /** + * Non-Aggregate function: takes the first letter of a string column and returns the ascii int value in a new column + * + * apache/spark + */ + def ascii[T]( + column: AbstractTypedColumn[T, String] + ): column.ThisType[T, Int] = column.typed(sparkFunctions.ascii(column.untyped)) - /** Non-Aggregate function: Computes the BASE64 encoding of a binary column and returns it as a string column. - * This is the reverse of unbase64. - * - * apache/spark - */ - def base64[T](column: AbstractTypedColumn[T, Array[Byte]]): column.ThisType[T, String] = + /** + * Non-Aggregate function: Computes the BASE64 encoding of a binary column and returns it as a string column. + * This is the reverse of unbase64. + * + * apache/spark + */ + def base64[T]( + column: AbstractTypedColumn[T, Array[Byte]] + ): column.ThisType[T, String] = column.typed(sparkFunctions.base64(column.untyped)) - /** Non-Aggregate function: Decodes a BASE64 encoded string column and returns it as a binary column. - * This is the reverse of base64. 
- * - * apache/spark - */ - def unbase64[T](column: AbstractTypedColumn[T, String]): column.ThisType[T, Array[Byte]] = + /** + * Non-Aggregate function: Decodes a BASE64 encoded string column and returns it as a binary column. + * This is the reverse of base64. + * + * apache/spark + */ + def unbase64[T]( + column: AbstractTypedColumn[T, String] + ): column.ThisType[T, Array[Byte]] = column.typed(sparkFunctions.unbase64(column.untyped)) - /** Non-Aggregate function: Concatenates multiple input string columns together into a single string column. - * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]] - * - * apache/spark - */ + /** + * Non-Aggregate function: Concatenates multiple input string columns together into a single string column. + * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]] + * + * apache/spark + */ def concat[T](columns: TypedColumn[T, String]*): TypedColumn[T, String] = new TypedColumn(sparkFunctions.concat(columns.map(_.untyped): _*)) - /** Non-Aggregate function: Concatenates multiple input string columns together into a single string column. - * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]] - * - * apache/spark - */ - def concat[T](columns: TypedAggregate[T, String]*): TypedAggregate[T, String] = + /** + * Non-Aggregate function: Concatenates multiple input string columns together into a single string column. + * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]] + * + * apache/spark + */ + def concat[T]( + columns: TypedAggregate[T, String]* + ): TypedAggregate[T, String] = new TypedAggregate(sparkFunctions.concat(columns.map(_.untyped): _*)) - /** Non-Aggregate function: Concatenates multiple input string columns together into a single string column, - * using the given separator. - * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]] - * - * apache/spark - */ - def concatWs[T](sep: String, columns: TypedAggregate[T, String]*): TypedAggregate[T, String] = - new TypedAggregate(sparkFunctions.concat_ws(sep, columns.map(_.untyped): _*)) - - /** Non-Aggregate function: Concatenates multiple input string columns together into a single string column, - * using the given separator. - * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]] - * - * apache/spark - */ - def concatWs[T](sep: String, columns: TypedColumn[T, String]*): TypedColumn[T, String] = + /** + * Non-Aggregate function: Concatenates multiple input string columns together into a single string column, + * using the given separator. + * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]] + * + * apache/spark + */ + def concatWs[T]( + sep: String, + columns: TypedAggregate[T, String]* + ): TypedAggregate[T, String] = + new TypedAggregate( + sparkFunctions.concat_ws(sep, columns.map(_.untyped): _*) + ) + + /** + * Non-Aggregate function: Concatenates multiple input string columns together into a single string column, + * using the given separator. 
+ * @note varargs make it harder to generalize so we overload the method for [[TypedColumn]] and [[TypedAggregate]] + * + * apache/spark + */ + def concatWs[T]( + sep: String, + columns: TypedColumn[T, String]* + ): TypedColumn[T, String] = new TypedColumn(sparkFunctions.concat_ws(sep, columns.map(_.untyped): _*)) - /** Non-Aggregate function: Locates the position of the first occurrence of substring column - * in given string - * - * @note The position is not zero based, but 1 based index. Returns 0 if substr - * could not be found in str. - * - * apache/spark - */ - def instr[T](str: AbstractTypedColumn[T, String], substring: String): str.ThisType[T, Int] = + /** + * Non-Aggregate function: Locates the position of the first occurrence of substring column + * in given string + * + * @note The position is not zero based, but 1 based index. Returns 0 if substr + * could not be found in str. + * + * apache/spark + */ + def instr[T]( + str: AbstractTypedColumn[T, String], + substring: String + ): str.ThisType[T, Int] = str.typed(sparkFunctions.instr(str.untyped, substring)) - /** Non-Aggregate function: Computes the length of a given string. - * - * apache/spark - */ - //TODO: Also for binary + /** + * Non-Aggregate function: Computes the length of a given string. + * + * apache/spark + */ + // TODO: Also for binary def length[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Int] = str.typed(sparkFunctions.length(str.untyped)) - /** Non-Aggregate function: Computes the Levenshtein distance of the two given string columns. - * - * apache/spark - */ - def levenshtein[T](l: TypedColumn[T, String], r: TypedColumn[T, String]): TypedColumn[T, Int] = + /** + * Non-Aggregate function: Computes the Levenshtein distance of the two given string columns. + * + * apache/spark + */ + def levenshtein[T]( + l: TypedColumn[T, String], + r: TypedColumn[T, String] + ): TypedColumn[T, Int] = l.typed(sparkFunctions.levenshtein(l.untyped, r.untyped)) - /** Non-Aggregate function: Computes the Levenshtein distance of the two given string columns. - * - * apache/spark - */ - def levenshtein[T](l: TypedAggregate[T, String], r: TypedAggregate[T, String]): TypedAggregate[T, Int] = + /** + * Non-Aggregate function: Computes the Levenshtein distance of the two given string columns. + * + * apache/spark + */ + def levenshtein[T]( + l: TypedAggregate[T, String], + r: TypedAggregate[T, String] + ): TypedAggregate[T, Int] = l.typed(sparkFunctions.levenshtein(l.untyped, r.untyped)) - /** Non-Aggregate function: Converts a string column to lower case. - * - * apache/spark - */ + /** + * Non-Aggregate function: Converts a string column to lower case. + * + * apache/spark + */ def lower[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] = str.typed(sparkFunctions.lower(str.untyped)) - /** Non-Aggregate function: Left-pad the string column with pad to a length of len. If the string column is longer - * than len, the return value is shortened to len characters. - * - * apache/spark - */ - def lpad[T](str: AbstractTypedColumn[T, String], - len: Int, - pad: String): str.ThisType[T, String] = + /** + * Non-Aggregate function: Left-pad the string column with pad to a length of len. If the string column is longer + * than len, the return value is shortened to len characters. 
+ * + * apache/spark + */ + def lpad[T]( + str: AbstractTypedColumn[T, String], + len: Int, + pad: String + ): str.ThisType[T, String] = str.typed(sparkFunctions.lpad(str.untyped, len, pad)) - /** Non-Aggregate function: Trim the spaces from left end for the specified string value. - * - * apache/spark - */ + /** + * Non-Aggregate function: Trim the spaces from left end for the specified string value. + * + * apache/spark + */ def ltrim[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] = str.typed(sparkFunctions.ltrim(str.untyped)) - /** Non-Aggregate function: Replace all substrings of the specified string value that match regexp with rep. - * - * apache/spark - */ - def regexpReplace[T](str: AbstractTypedColumn[T, String], - pattern: Regex, - replacement: String): str.ThisType[T, String] = - str.typed(sparkFunctions.regexp_replace(str.untyped, pattern.regex, replacement)) - + /** + * Non-Aggregate function: Replace all substrings of the specified string value that match regexp with rep. + * + * apache/spark + */ + def regexpReplace[T]( + str: AbstractTypedColumn[T, String], + pattern: Regex, + replacement: String + ): str.ThisType[T, String] = + str.typed( + sparkFunctions.regexp_replace(str.untyped, pattern.regex, replacement) + ) - /** Non-Aggregate function: Reverses the string column and returns it as a new string column. - * - * apache/spark - */ + /** + * Non-Aggregate function: Reverses the string column and returns it as a new string column. + * + * apache/spark + */ def reverse[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] = str.typed(sparkFunctions.reverse(str.untyped)) - /** Non-Aggregate function: Right-pad the string column with pad to a length of len. - * If the string column is longer than len, the return value is shortened to len characters. - * - * apache/spark - */ - def rpad[T](str: AbstractTypedColumn[T, String], len: Int, pad: String): str.ThisType[T, String] = + /** + * Non-Aggregate function: Right-pad the string column with pad to a length of len. + * If the string column is longer than len, the return value is shortened to len characters. + * + * apache/spark + */ + def rpad[T]( + str: AbstractTypedColumn[T, String], + len: Int, + pad: String + ): str.ThisType[T, String] = str.typed(sparkFunctions.rpad(str.untyped, len, pad)) - /** Non-Aggregate function: Trim the spaces from right end for the specified string value. - * - * apache/spark - */ + /** + * Non-Aggregate function: Trim the spaces from right end for the specified string value. + * + * apache/spark + */ def rtrim[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] = str.typed(sparkFunctions.rtrim(str.untyped)) - /** Non-Aggregate function: Substring starts at `pos` and is of length `len` - * - * apache/spark - */ - //TODO: Also for byte array - def substring[T](str: AbstractTypedColumn[T, String], pos: Int, len: Int): str.ThisType[T, String] = + /** + * Non-Aggregate function: Substring starts at `pos` and is of length `len` + * + * apache/spark + */ + // TODO: Also for byte array + def substring[T]( + str: AbstractTypedColumn[T, String], + pos: Int, + len: Int + ): str.ThisType[T, String] = str.typed(sparkFunctions.substring(str.untyped, pos, len)) - /** Non-Aggregate function: Trim the spaces from both ends for the specified string column. - * - * apache/spark - */ + /** + * Non-Aggregate function: Trim the spaces from both ends for the specified string column. 
+ * + * apache/spark + */ def trim[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] = str.typed(sparkFunctions.trim(str.untyped)) - /** Non-Aggregate function: Converts a string column to upper case. - * - * apache/spark - */ + /** + * Non-Aggregate function: Converts a string column to upper case. + * + * apache/spark + */ def upper[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, String] = str.typed(sparkFunctions.upper(str.untyped)) @@ -715,93 +965,123 @@ trait NonAggregateFunctions { // DateTime functions ////////////////////////////////////////////////////////////////////////////////////////////// - /** Non-Aggregate function: Extracts the year as an integer from a given date/timestamp/string. - * - * Differs from `Column#year` by wrapping it's result into an `Option`. - * - * apache/spark - */ - def year[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] = + /** + * Non-Aggregate function: Extracts the year as an integer from a given date/timestamp/string. + * + * Differs from `Column#year` by wrapping it's result into an `Option`. + * + * apache/spark + */ + def year[T]( + str: AbstractTypedColumn[T, String] + ): str.ThisType[T, Option[Int]] = str.typed(sparkFunctions.year(str.untyped)) - /** Non-Aggregate function: Extracts the quarter as an integer from a given date/timestamp/string. - * - * Differs from `Column#quarter` by wrapping it's result into an `Option`. - * - * apache/spark - */ - def quarter[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] = + /** + * Non-Aggregate function: Extracts the quarter as an integer from a given date/timestamp/string. + * + * Differs from `Column#quarter` by wrapping it's result into an `Option`. + * + * apache/spark + */ + def quarter[T]( + str: AbstractTypedColumn[T, String] + ): str.ThisType[T, Option[Int]] = str.typed(sparkFunctions.quarter(str.untyped)) - /** Non-Aggregate function Extracts the month as an integer from a given date/timestamp/string. - * - * Differs from `Column#month` by wrapping it's result into an `Option`. - * - * apache/spark - */ - def month[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] = + /** + * Non-Aggregate function Extracts the month as an integer from a given date/timestamp/string. + * + * Differs from `Column#month` by wrapping it's result into an `Option`. + * + * apache/spark + */ + def month[T]( + str: AbstractTypedColumn[T, String] + ): str.ThisType[T, Option[Int]] = str.typed(sparkFunctions.month(str.untyped)) - /** Non-Aggregate function: Extracts the day of the week as an integer from a given date/timestamp/string. - * - * Differs from `Column#dayofweek` by wrapping it's result into an `Option`. - * - * apache/spark - */ - def dayofweek[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] = + /** + * Non-Aggregate function: Extracts the day of the week as an integer from a given date/timestamp/string. + * + * Differs from `Column#dayofweek` by wrapping it's result into an `Option`. + * + * apache/spark + */ + def dayofweek[T]( + str: AbstractTypedColumn[T, String] + ): str.ThisType[T, Option[Int]] = str.typed(sparkFunctions.dayofweek(str.untyped)) - /** Non-Aggregate function: Extracts the day of the month as an integer from a given date/timestamp/string. - * - * Differs from `Column#dayofmonth` by wrapping it's result into an `Option`. 
- * - * apache/spark - */ - def dayofmonth[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] = + /** + * Non-Aggregate function: Extracts the day of the month as an integer from a given date/timestamp/string. + * + * Differs from `Column#dayofmonth` by wrapping it's result into an `Option`. + * + * apache/spark + */ + def dayofmonth[T]( + str: AbstractTypedColumn[T, String] + ): str.ThisType[T, Option[Int]] = str.typed(sparkFunctions.dayofmonth(str.untyped)) - /** Non-Aggregate function: Extracts the day of the year as an integer from a given date/timestamp/string. - * - * Differs from `Column#dayofyear` by wrapping it's result into an `Option`. - * - * apache/spark - */ - def dayofyear[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] = + /** + * Non-Aggregate function: Extracts the day of the year as an integer from a given date/timestamp/string. + * + * Differs from `Column#dayofyear` by wrapping it's result into an `Option`. + * + * apache/spark + */ + def dayofyear[T]( + str: AbstractTypedColumn[T, String] + ): str.ThisType[T, Option[Int]] = str.typed(sparkFunctions.dayofyear(str.untyped)) - /** Non-Aggregate function: Extracts the hours as an integer from a given date/timestamp/string. - * - * Differs from `Column#hour` by wrapping it's result into an `Option`. - * - * apache/spark - */ - def hour[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] = + /** + * Non-Aggregate function: Extracts the hours as an integer from a given date/timestamp/string. + * + * Differs from `Column#hour` by wrapping it's result into an `Option`. + * + * apache/spark + */ + def hour[T]( + str: AbstractTypedColumn[T, String] + ): str.ThisType[T, Option[Int]] = str.typed(sparkFunctions.hour(str.untyped)) - /** Non-Aggregate function: Extracts the minutes as an integer from a given date/timestamp/string. - * - * Differs from `Column#minute` by wrapping it's result into an `Option`. - * - * apache/spark - */ - def minute[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] = + /** + * Non-Aggregate function: Extracts the minutes as an integer from a given date/timestamp/string. + * + * Differs from `Column#minute` by wrapping it's result into an `Option`. + * + * apache/spark + */ + def minute[T]( + str: AbstractTypedColumn[T, String] + ): str.ThisType[T, Option[Int]] = str.typed(sparkFunctions.minute(str.untyped)) - /** Non-Aggregate function: Extracts the seconds as an integer from a given date/timestamp/string. - * - * Differs from `Column#second` by wrapping it's result into an `Option`. - * - * apache/spark - */ - def second[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] = + /** + * Non-Aggregate function: Extracts the seconds as an integer from a given date/timestamp/string. + * + * Differs from `Column#second` by wrapping it's result into an `Option`. + * + * apache/spark + */ + def second[T]( + str: AbstractTypedColumn[T, String] + ): str.ThisType[T, Option[Int]] = str.typed(sparkFunctions.second(str.untyped)) - /** Non-Aggregate function: Extracts the week number as an integer from a given date/timestamp/string. - * - * Differs from `Column#weekofyear` by wrapping it's result into an `Option`. - * - * apache/spark - */ - def weekofyear[T](str: AbstractTypedColumn[T, String]): str.ThisType[T, Option[Int]] = + /** + * Non-Aggregate function: Extracts the week number as an integer from a given date/timestamp/string. + * + * Differs from `Column#weekofyear` by wrapping it's result into an `Option`. 
+ * + * apache/spark + */ + def weekofyear[T]( + str: AbstractTypedColumn[T, String] + ): str.ThisType[T, Option[Int]] = str.typed(sparkFunctions.weekofyear(str.untyped)) } diff --git a/dataset/src/main/scala/frameless/functions/Udf.scala b/dataset/src/main/scala/frameless/functions/Udf.scala index 93ba7f118..42d65e8ca 100644 --- a/dataset/src/main/scala/frameless/functions/Udf.scala +++ b/dataset/src/main/scala/frameless/functions/Udf.scala @@ -2,90 +2,118 @@ package frameless package functions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, NonSQLExpression} +import org.apache.spark.sql.catalyst.expressions.{ + Expression, + LeafExpression, + NonSQLExpression +} import org.apache.spark.sql.catalyst.expressions.codegen._ import Block._ import org.apache.spark.sql.types.DataType import shapeless.syntax.std.tuple._ -/** Documentation marked "apache/spark" is thanks to apache/spark Contributors - * at https://github.com/apache/spark, licensed under Apache v2.0 available at - * http://www.apache.org/licenses/LICENSE-2.0 - */ +/** + * Documentation marked "apache/spark" is thanks to apache/spark Contributors + * at https://github.com/apache/spark, licensed under Apache v2.0 available at + * http://www.apache.org/licenses/LICENSE-2.0 + */ trait Udf { - /** Defines a user-defined function of 1 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. - * - * apache/spark - */ - def udf[T, A, R: TypedEncoder](f: A => R): - TypedColumn[T, A] => TypedColumn[T, R] = { - u => - val scalaUdf = FramelessUdf(f, List(u), TypedEncoder[R]) - new TypedColumn[T, R](scalaUdf) + /** + * Defines a user-defined function of 1 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * apache/spark + */ + def udf[T, A, R: TypedEncoder]( + f: A => R + ): TypedColumn[T, A] => TypedColumn[T, R] = { u => + val scalaUdf = FramelessUdf(f, List(u), TypedEncoder[R]) + new TypedColumn[T, R](scalaUdf) } - /** Defines a user-defined function of 2 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. - * - * apache/spark - */ - def udf[T, A1, A2, R: TypedEncoder](f: (A1,A2) => R): - (TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R] = { + /** + * Defines a user-defined function of 2 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * apache/spark + */ + def udf[T, A1, A2, R: TypedEncoder]( + f: (A1, A2) => R + ): (TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R] = { case us => - val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) + val scalaUdf = + FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) new TypedColumn[T, R](scalaUdf) - } + } - /** Defines a user-defined function of 3 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. - * - * apache/spark - */ - def udf[T, A1, A2, A3, R: TypedEncoder](f: (A1,A2,A3) => R): - (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R] = { + /** + * Defines a user-defined function of 3 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. 
+ * + * apache/spark + */ + def udf[T, A1, A2, A3, R: TypedEncoder]( + f: (A1, A2, A3) => R + ): (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R] = { case us => - val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) + val scalaUdf = + FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) new TypedColumn[T, R](scalaUdf) - } + } - /** Defines a user-defined function of 4 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. - * - * apache/spark - */ - def udf[T, A1, A2, A3, A4, R: TypedEncoder](f: (A1,A2,A3,A4) => R): - (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R] = { + /** + * Defines a user-defined function of 4 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * apache/spark + */ + def udf[T, A1, A2, A3, A4, R: TypedEncoder](f: (A1, A2, A3, A4) => R): ( + TypedColumn[T, A1], + TypedColumn[T, A2], + TypedColumn[T, A3], + TypedColumn[T, A4] + ) => TypedColumn[T, R] = { case us => - val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) + val scalaUdf = + FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) new TypedColumn[T, R](scalaUdf) - } + } - /** Defines a user-defined function of 5 arguments as user-defined function (UDF). - * The data types are automatically inferred based on the function's signature. - * - * apache/spark - */ - def udf[T, A1, A2, A3, A4, A5, R: TypedEncoder](f: (A1,A2,A3,A4,A5) => R): - (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R] = { + /** + * Defines a user-defined function of 5 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the function's signature. + * + * apache/spark + */ + def udf[T, A1, A2, A3, A4, A5, R: TypedEncoder]( + f: (A1, A2, A3, A4, A5) => R + ): ( + TypedColumn[T, A1], + TypedColumn[T, A2], + TypedColumn[T, A3], + TypedColumn[T, A4], + TypedColumn[T, A5] + ) => TypedColumn[T, R] = { case us => - val scalaUdf = FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) + val scalaUdf = + FramelessUdf(f, us.toList[UntypedExpression[T]], TypedEncoder[R]) new TypedColumn[T, R](scalaUdf) - } + } } /** - * NB: Implementation detail, isn't intended to be directly used. - * - * Our own implementation of `ScalaUDF` from Catalyst compatible with [[TypedEncoder]]. - */ + * NB: Implementation detail, isn't intended to be directly used. + * + * Our own implementation of `ScalaUDF` from Catalyst compatible with [[TypedEncoder]]. 
+ */ case class FramelessUdf[T, R]( - function: AnyRef, - encoders: Seq[TypedEncoder[_]], - children: Seq[Expression], - rencoder: TypedEncoder[R] -) extends Expression with NonSQLExpression { + function: AnyRef, + encoders: Seq[TypedEncoder[_]], + children: Seq[Expression], + rencoder: TypedEncoder[R]) + extends Expression + with NonSQLExpression { override def nullable: Boolean = rencoder.nullable override def toString: String = s"FramelessUdf(${children.mkString(", ")})" @@ -118,10 +146,12 @@ case class FramelessUdf[T, R]( """ val code = CodeFormatter.stripOverlappingComments( - new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) + new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()) + ) val (clazz, _) = CodeGenerator.compile(code) - val codegen = clazz.generate(ctx.references.toArray).asInstanceOf[InternalRow => AnyRef] + val codegen = + clazz.generate(ctx.references.toArray).asInstanceOf[InternalRow => AnyRef] codegen } @@ -139,29 +169,45 @@ case class FramelessUdf[T, R]( val framelessUdfClassName = classOf[FramelessUdf[_, _]].getName val funcClassName = s"scala.Function${children.size}" val funcExpressionIdx = ctx.references.size - 1 - val funcTerm = ctx.addMutableState(funcClassName, ctx.freshName("udf"), - v => s"$v = ($funcClassName)((($framelessUdfClassName)references" + - s"[$funcExpressionIdx]).function());") - - val (argsCode, funcArguments) = encoders.zip(children).map { - case (encoder, child) => - val eval = child.genCode(ctx) - val codeTpe = CodeGenerator.boxedType(encoder.jvmRepr) - val argTerm = ctx.freshName("arg") - val convert = s"${eval.code}\n$codeTpe $argTerm = ${eval.isNull} ? (($codeTpe)null) : (($codeTpe)(${eval.value}));" + val funcTerm = ctx.addMutableState( + funcClassName, + ctx.freshName("udf"), + v => + s"$v = ($funcClassName)((($framelessUdfClassName)references" + + s"[$funcExpressionIdx]).function());" + ) - (convert, argTerm) - }.unzip + val (argsCode, funcArguments) = encoders + .zip(children) + .map { + case (encoder, child) => + val eval = child.genCode(ctx) + val codeTpe = CodeGenerator.boxedType(encoder.jvmRepr) + val argTerm = ctx.freshName("arg") + val convert = + s"${eval.code}\n$codeTpe $argTerm = ${eval.isNull} ? 
(($codeTpe)null) : (($codeTpe)(${eval.value}));" + + (convert, argTerm) + } + .unzip val internalTpe = CodeGenerator.boxedType(rencoder.jvmRepr) - val internalTerm = ctx.addMutableState(internalTpe, ctx.freshName("internal")) - val internalNullTerm = ctx.addMutableState("boolean", ctx.freshName("internalNull")) + val internalTerm = + ctx.addMutableState(internalTpe, ctx.freshName("internal")) + val internalNullTerm = + ctx.addMutableState("boolean", ctx.freshName("internalNull")) // CTw - can't inject the term, may have to duplicate old code for parity - val internalExpr = Spark2_4_LambdaVariable(internalTerm, internalNullTerm, rencoder.jvmRepr, true) + val internalExpr = Spark2_4_LambdaVariable( + internalTerm, + internalNullTerm, + rencoder.jvmRepr, + true + ) val resultEval = rencoder.toCatalyst(internalExpr).genCode(ctx) - ev.copy(code = code""" + ev.copy( + code = code""" ${argsCode.mkString("\n")} $internalTerm = @@ -175,21 +221,28 @@ case class FramelessUdf[T, R]( ) } - protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren) + protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression] + ): Expression = copy(children = newChildren) } case class Spark2_4_LambdaVariable( - value: String, - isNull: String, - dataType: DataType, - nullable: Boolean = true) extends LeafExpression with NonSQLExpression { + value: String, + isNull: String, + dataType: DataType, + nullable: Boolean = true) + extends LeafExpression + with NonSQLExpression { - private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType) + private val accessor: (InternalRow, Int) => Any = + InternalRow.getAccessor(dataType) // Interpreted execution of `LambdaVariable` always get the 0-index element from input row. override def eval(input: InternalRow): Any = { - assert(input.numFields == 1, - "The input row of interpreted LambdaVariable should have only 1 field.") + assert( + input.numFields == 1, + "The input row of interpreted LambdaVariable should have only 1 field." 
+ ) if (nullable && input.isNullAt(0)) { null } else { @@ -197,7 +250,10 @@ case class Spark2_4_LambdaVariable( } } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + override protected def doGenCode( + ctx: CodegenContext, + ev: ExprCode + ): ExprCode = { val isNullValue = if (nullable) { JavaCode.isNullVariable(isNull) } else { @@ -208,12 +264,13 @@ case class Spark2_4_LambdaVariable( } object FramelessUdf { + // Spark needs case class with `children` field to mutate it def apply[T, R]( - function: AnyRef, - cols: Seq[UntypedExpression[T]], - rencoder: TypedEncoder[R] - ): FramelessUdf[T, R] = FramelessUdf( + function: AnyRef, + cols: Seq[UntypedExpression[T]], + rencoder: TypedEncoder[R] + ): FramelessUdf[T, R] = FramelessUdf( function = function, encoders = cols.map(_.uencoder).toList, children = cols.map(x => x.uencoder.fromCatalyst(x.expr)).toList, diff --git a/dataset/src/main/scala/frameless/functions/UnaryFunctions.scala b/dataset/src/main/scala/frameless/functions/UnaryFunctions.scala index 64bdf0ed1..6a9f41cdd 100644 --- a/dataset/src/main/scala/frameless/functions/UnaryFunctions.scala +++ b/dataset/src/main/scala/frameless/functions/UnaryFunctions.scala @@ -1,50 +1,74 @@ package frameless package functions -import org.apache.spark.sql.{Column, functions => sparkFunctions} +import org.apache.spark.sql.{ Column, functions => sparkFunctions } import scala.math.Ordering trait UnaryFunctions { - /** Returns length of array - * - * apache/spark - */ - def size[T, A, V[_] : CatalystSizableCollection](column: TypedColumn[T, V[A]]): TypedColumn[T, Int] = - new TypedColumn[T, Int](implicitly[CatalystSizableCollection[V]].sizeOp(column.untyped)) - - /** Returns length of Map - * - * apache/spark - */ + + /** + * Returns length of array + * + * apache/spark + */ + def size[T, A, V[_]: CatalystSizableCollection]( + column: TypedColumn[T, V[A]] + ): TypedColumn[T, Int] = + new TypedColumn[T, Int]( + implicitly[CatalystSizableCollection[V]].sizeOp(column.untyped) + ) + + /** + * Returns length of Map + * + * apache/spark + */ def size[T, A, B](column: TypedColumn[T, Map[A, B]]): TypedColumn[T, Int] = new TypedColumn[T, Int](sparkFunctions.size(column.untyped)) - /** Sorts the input array for the given column in ascending order, according to - * the natural ordering of the array elements. - * - * apache/spark - */ - def sortAscending[T, A: Ordering, V[_] : CatalystSortableCollection](column: TypedColumn[T, V[A]]): TypedColumn[T, V[A]] = - new TypedColumn[T, V[A]](implicitly[CatalystSortableCollection[V]].sortOp(column.untyped, sortAscending = true))(column.uencoder) - - /** Sorts the input array for the given column in descending order, according to - * the natural ordering of the array elements. - * - * apache/spark - */ - def sortDescending[T, A: Ordering, V[_] : CatalystSortableCollection](column: TypedColumn[T, V[A]]): TypedColumn[T, V[A]] = - new TypedColumn[T, V[A]](implicitly[CatalystSortableCollection[V]].sortOp(column.untyped, sortAscending = false))(column.uencoder) - - - /** Creates a new row for each element in the given collection. The column types - * eligible for this operation are constrained by CatalystExplodableCollection. - * - * apache/spark - */ - @deprecated("Use explode() from the TypedDataset instead. 
This method will result in " + - "runtime error if applied to two columns in the same select statement.", "0.6.2") - def explode[T, A: TypedEncoder, V[_] : CatalystExplodableCollection](column: TypedColumn[T, V[A]]): TypedColumn[T, A] = + /** + * Sorts the input array for the given column in ascending order, according to + * the natural ordering of the array elements. + * + * apache/spark + */ + def sortAscending[T, A: Ordering, V[_]: CatalystSortableCollection]( + column: TypedColumn[T, V[A]] + ): TypedColumn[T, V[A]] = + new TypedColumn[T, V[A]]( + implicitly[CatalystSortableCollection[V]] + .sortOp(column.untyped, sortAscending = true) + )(column.uencoder) + + /** + * Sorts the input array for the given column in descending order, according to + * the natural ordering of the array elements. + * + * apache/spark + */ + def sortDescending[T, A: Ordering, V[_]: CatalystSortableCollection]( + column: TypedColumn[T, V[A]] + ): TypedColumn[T, V[A]] = + new TypedColumn[T, V[A]]( + implicitly[CatalystSortableCollection[V]] + .sortOp(column.untyped, sortAscending = false) + )(column.uencoder) + + /** + * Creates a new row for each element in the given collection. The column types + * eligible for this operation are constrained by CatalystExplodableCollection. + * + * apache/spark + */ + @deprecated( + "Use explode() from the TypedDataset instead. This method will result in " + + "runtime error if applied to two columns in the same select statement.", + "0.6.2" + ) + def explode[T, A: TypedEncoder, V[_]: CatalystExplodableCollection]( + column: TypedColumn[T, V[A]] + ): TypedColumn[T, A] = new TypedColumn[T, A](sparkFunctions.explode(column.untyped)) } @@ -53,27 +77,39 @@ trait CatalystSizableCollection[V[_]] { } object CatalystSizableCollection { - implicit def sizableVector: CatalystSizableCollection[Vector] = new CatalystSizableCollection[Vector] { - def sizeOp(col: Column): Column = sparkFunctions.size(col) - } - implicit def sizableArray: CatalystSizableCollection[Array] = new CatalystSizableCollection[Array] { - def sizeOp(col: Column): Column = sparkFunctions.size(col) - } + implicit def sizableVector: CatalystSizableCollection[Vector] = + new CatalystSizableCollection[Vector] { + def sizeOp(col: Column): Column = sparkFunctions.size(col) + } + + implicit def sizableArray: CatalystSizableCollection[Array] = + new CatalystSizableCollection[Array] { + def sizeOp(col: Column): Column = sparkFunctions.size(col) + } - implicit def sizableList: CatalystSizableCollection[List] = new CatalystSizableCollection[List] { - def sizeOp(col: Column): Column = sparkFunctions.size(col) - } + implicit def sizableList: CatalystSizableCollection[List] = + new CatalystSizableCollection[List] { + def sizeOp(col: Column): Column = sparkFunctions.size(col) + } } trait CatalystExplodableCollection[V[_]] object CatalystExplodableCollection { - implicit def explodableVector: CatalystExplodableCollection[Vector] = new CatalystExplodableCollection[Vector] {} - implicit def explodableArray: CatalystExplodableCollection[Array] = new CatalystExplodableCollection[Array] {} - implicit def explodableList: CatalystExplodableCollection[List] = new CatalystExplodableCollection[List] {} - implicit def explodableSeq: CatalystExplodableCollection[Seq] = new CatalystExplodableCollection[Seq] {} + + implicit def explodableVector: CatalystExplodableCollection[Vector] = + new CatalystExplodableCollection[Vector] {} + + implicit def explodableArray: CatalystExplodableCollection[Array] = + new CatalystExplodableCollection[Array] {} + + 
implicit def explodableList: CatalystExplodableCollection[List] = + new CatalystExplodableCollection[List] {} + + implicit def explodableSeq: CatalystExplodableCollection[Seq] = + new CatalystExplodableCollection[Seq] {} } trait CatalystSortableCollection[V[_]] { @@ -81,15 +117,22 @@ trait CatalystSortableCollection[V[_]] { } object CatalystSortableCollection { - implicit def sortableVector: CatalystSortableCollection[Vector] = new CatalystSortableCollection[Vector] { - def sortOp(col: Column, sortAscending: Boolean): Column = sparkFunctions.sort_array(col, sortAscending) - } - - implicit def sortableArray: CatalystSortableCollection[Array] = new CatalystSortableCollection[Array] { - def sortOp(col: Column, sortAscending: Boolean): Column = sparkFunctions.sort_array(col, sortAscending) - } - implicit def sortableList: CatalystSortableCollection[List] = new CatalystSortableCollection[List] { - def sortOp(col: Column, sortAscending: Boolean): Column = sparkFunctions.sort_array(col, sortAscending) - } + implicit def sortableVector: CatalystSortableCollection[Vector] = + new CatalystSortableCollection[Vector] { + def sortOp(col: Column, sortAscending: Boolean): Column = + sparkFunctions.sort_array(col, sortAscending) + } + + implicit def sortableArray: CatalystSortableCollection[Array] = + new CatalystSortableCollection[Array] { + def sortOp(col: Column, sortAscending: Boolean): Column = + sparkFunctions.sort_array(col, sortAscending) + } + + implicit def sortableList: CatalystSortableCollection[List] = + new CatalystSortableCollection[List] { + def sortOp(col: Column, sortAscending: Boolean): Column = + sparkFunctions.sort_array(col, sortAscending) + } } diff --git a/dataset/src/main/scala/frameless/ops/AggregateTypes.scala b/dataset/src/main/scala/frameless/ops/AggregateTypes.scala index 403c25301..58d8cb27f 100644 --- a/dataset/src/main/scala/frameless/ops/AggregateTypes.scala +++ b/dataset/src/main/scala/frameless/ops/AggregateTypes.scala @@ -3,26 +3,32 @@ package ops import shapeless._ -/** A type class to extract the column types out of an HList of [[frameless.TypedAggregate]]. - * - * @note This type class is mostly a workaround to issue with slow implicit derivation for Comapped. - * @example - * {{{ - * type U = TypedAggregate[T,A] :: TypedAggregate[T,B] :: TypedAggregate[T,C] :: HNil - * type Out = A :: B :: C :: HNil - * }}} - */ +/** + * A type class to extract the column types out of an HList of [[frameless.TypedAggregate]]. + * + * @note This type class is mostly a workaround to issue with slow implicit derivation for Comapped. 
+ * @example + * {{{ + * type U = TypedAggregate[T,A] :: TypedAggregate[T,B] :: TypedAggregate[T,C] :: HNil + * type Out = A :: B :: C :: HNil + * }}} + */ trait AggregateTypes[V, U <: HList] { type Out <: HList } object AggregateTypes { - type Aux[V, U <: HList, Out0 <: HList] = AggregateTypes[V, U] {type Out = Out0} - implicit def deriveHNil[T]: AggregateTypes.Aux[T, HNil, HNil] = new AggregateTypes[T, HNil] { type Out = HNil } + type Aux[V, U <: HList, Out0 <: HList] = AggregateTypes[V, U] { + type Out = Out0 + } + + implicit def deriveHNil[T]: AggregateTypes.Aux[T, HNil, HNil] = + new AggregateTypes[T, HNil] { type Out = HNil } implicit def deriveCons1[T, H, TT <: HList, V <: HList]( - implicit tail: AggregateTypes.Aux[T, TT, V] - ): AggregateTypes.Aux[T, TypedAggregate[T, H] :: TT, H :: V] = - new AggregateTypes[T, TypedAggregate[T, H] :: TT] {type Out = H :: V} + implicit + tail: AggregateTypes.Aux[T, TT, V] + ): AggregateTypes.Aux[T, TypedAggregate[T, H] :: TT, H :: V] = + new AggregateTypes[T, TypedAggregate[T, H] :: TT] { type Out = H :: V } } diff --git a/dataset/src/main/scala/frameless/ops/As.scala b/dataset/src/main/scala/frameless/ops/As.scala index 06b691028..e9c553f72 100644 --- a/dataset/src/main/scala/frameless/ops/As.scala +++ b/dataset/src/main/scala/frameless/ops/As.scala @@ -1,10 +1,12 @@ package frameless package ops -import shapeless.{::, Generic, HList, Lazy} +import shapeless.{ ::, Generic, HList, Lazy } /** Evidence for correctness of `TypedDataset[T].as[U]` */ -class As[T, U] private (implicit val encoder: TypedEncoder[U]) +class As[T, U] private ( + implicit + val encoder: TypedEncoder[U]) object As extends LowPriorityAs { @@ -12,8 +14,8 @@ object As extends LowPriorityAs { implicit def equivIdentity[A] = new Equiv[A, A] - implicit def deriveAs[A, B] - (implicit + implicit def deriveAs[A, B]( + implicit i0: TypedEncoder[B], i1: Equiv[A, B] ): As[A, B] = new As[A, B] @@ -24,14 +26,14 @@ trait LowPriorityAs { import As.Equiv - implicit def equivHList[AH, AT <: HList, BH, BT <: HList] - (implicit + implicit def equivHList[AH, AT <: HList, BH, BT <: HList]( + implicit i0: Lazy[Equiv[AH, BH]], i1: Equiv[AT, BT] ): Equiv[AH :: AT, BH :: BT] = new Equiv[AH :: AT, BH :: BT] - implicit def equivGeneric[A, B, R, S] - (implicit + implicit def equivGeneric[A, B, R, S]( + implicit i0: Generic.Aux[A, R], i1: Generic.Aux[B, S], i2: Lazy[Equiv[R, S]] diff --git a/dataset/src/main/scala/frameless/ops/ColumnTypes.scala b/dataset/src/main/scala/frameless/ops/ColumnTypes.scala index e5ae6aea2..923b0a168 100644 --- a/dataset/src/main/scala/frameless/ops/ColumnTypes.scala +++ b/dataset/src/main/scala/frameless/ops/ColumnTypes.scala @@ -3,26 +3,29 @@ package ops import shapeless._ -/** A type class to extract the column types out of an HList of [[frameless.TypedColumn]]. - * - * @note This type class is mostly a workaround to issue with slow implicit derivation for Comapped. - * @example - * {{{ - * type U = TypedColumn[T,A] :: TypedColumn[T,B] :: TypedColumn[T,C] :: HNil - * type Out = A :: B :: C :: HNil - * }}} - */ +/** + * A type class to extract the column types out of an HList of [[frameless.TypedColumn]]. + * + * @note This type class is mostly a workaround to issue with slow implicit derivation for Comapped. 
+ * @example + * {{{ + * type U = TypedColumn[T,A] :: TypedColumn[T,B] :: TypedColumn[T,C] :: HNil + * type Out = A :: B :: C :: HNil + * }}} + */ trait ColumnTypes[T, U <: HList] { type Out <: HList } object ColumnTypes { - type Aux[T, U <: HList, Out0 <: HList] = ColumnTypes[T, U] {type Out = Out0} + type Aux[T, U <: HList, Out0 <: HList] = ColumnTypes[T, U] { type Out = Out0 } - implicit def deriveHNil[T]: ColumnTypes.Aux[T, HNil, HNil] = new ColumnTypes[T, HNil] { type Out = HNil } + implicit def deriveHNil[T]: ColumnTypes.Aux[T, HNil, HNil] = + new ColumnTypes[T, HNil] { type Out = HNil } implicit def deriveCons[T, H, TT <: HList, V <: HList]( - implicit tail: ColumnTypes.Aux[T, TT, V] - ): ColumnTypes.Aux[T, TypedColumn[T, H] :: TT, H :: V] = - new ColumnTypes[T, TypedColumn[T, H] :: TT] {type Out = H :: V} + implicit + tail: ColumnTypes.Aux[T, TT, V] + ): ColumnTypes.Aux[T, TypedColumn[T, H] :: TT, H :: V] = + new ColumnTypes[T, TypedColumn[T, H] :: TT] { type Out = H :: V } } diff --git a/dataset/src/main/scala/frameless/ops/GroupByOps.scala b/dataset/src/main/scala/frameless/ops/GroupByOps.scala index 3feeaca59..e4425e8f3 100644 --- a/dataset/src/main/scala/frameless/ops/GroupByOps.scala +++ b/dataset/src/main/scala/frameless/ops/GroupByOps.scala @@ -3,36 +3,54 @@ package ops import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias import org.apache.spark.sql.catalyst.plans.logical.Project -import org.apache.spark.sql.{Column, Dataset, FramelessInternals, RelationalGroupedDataset} +import org.apache.spark.sql.{ + Column, + Dataset, + FramelessInternals, + RelationalGroupedDataset +} import shapeless._ -import shapeless.ops.hlist.{Length, Mapped, Prepend, ToList, ToTraversable, Tupler} +import shapeless.ops.hlist.{ + Length, + Mapped, + Prepend, + ToList, + ToTraversable, + Tupler +} -class GroupedByManyOps[T, TK <: HList, K <: HList, KT] - (self: TypedDataset[T], groupedBy: TK) - (implicit +class GroupedByManyOps[T, TK <: HList, K <: HList, KT]( + self: TypedDataset[T], + groupedBy: TK + )(implicit i0: ColumnTypes.Aux[T, TK, K], i1: ToTraversable.Aux[TK, List, UntypedExpression[T]], - i3: Tupler.Aux[K, KT] - ) extends AggregatingOps[T, TK, K, KT](self, groupedBy, (dataset, cols) => dataset.groupBy(cols: _*)) { + i3: Tupler.Aux[K, KT]) + extends AggregatingOps[T, TK, K, KT]( + self, + groupedBy, + (dataset, cols) => dataset.groupBy(cols: _*) + ) { + object agg extends ProductArgs { - def applyProduct[TC <: HList, C <: HList, Out0 <: HList, Out1] - (columns: TC) - (implicit + + def applyProduct[TC <: HList, C <: HList, Out0 <: HList, Out1]( + columns: TC + )(implicit i3: AggregateTypes.Aux[T, TC, C], i4: Prepend.Aux[K, C, Out0], i5: Tupler.Aux[Out0, Out1], i6: TypedEncoder[Out1], i7: ToTraversable.Aux[TC, List, UntypedExpression[T]] ): TypedDataset[Out1] = { - aggregate[TC, Out1](columns) - } + aggregate[TC, Out1](columns) + } } } class GroupedBy1Ops[K1, V]( - self: TypedDataset[V], - g1: TypedColumn[V, K1] -) { + self: TypedDataset[V], + g1: TypedColumn[V, K1]) { private def underlying = new GroupedByManyOps(self, g1 :: HNil) private implicit def eg1 = g1.uencoder @@ -41,49 +59,77 @@ class GroupedBy1Ops[K1, V]( underlying.agg(c1) } - def agg[U1, U2](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2]): TypedDataset[(K1, U1, U2)] = { + def agg[U1, U2]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2] + ): TypedDataset[(K1, U1, U2)] = { implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder underlying.agg(c1, c2) } - def agg[U1, U2, U3](c1: TypedAggregate[V, U1], 
c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3]): TypedDataset[(K1, U1, U2, U3)] = { - implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder + def agg[U1, U2, U3]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2], + c3: TypedAggregate[V, U3] + ): TypedDataset[(K1, U1, U2, U3)] = { + implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; + implicit val e3 = c3.uencoder underlying.agg(c1, c2, c3) } - def agg[U1, U2, U3, U4](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3], c4: TypedAggregate[V, U4]): TypedDataset[(K1, U1, U2, U3, U4)] = { - implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder + def agg[U1, U2, U3, U4]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2], + c3: TypedAggregate[V, U3], + c4: TypedAggregate[V, U4] + ): TypedDataset[(K1, U1, U2, U3, U4)] = { + implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; + implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder underlying.agg(c1, c2, c3, c4) } - def agg[U1, U2, U3, U4, U5](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3], c4: TypedAggregate[V, U4], c5: TypedAggregate[V, U5]): TypedDataset[(K1, U1, U2, U3, U4, U5)] = { - implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder; implicit val e5 = c5.uencoder + def agg[U1, U2, U3, U4, U5]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2], + c3: TypedAggregate[V, U3], + c4: TypedAggregate[V, U4], + c5: TypedAggregate[V, U5] + ): TypedDataset[(K1, U1, U2, U3, U4, U5)] = { + implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; + implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder; + implicit val e5 = c5.uencoder underlying.agg(c1, c2, c3, c4, c5) } - /** Methods on `TypedDataset[T]` that go through a full serialization and - * deserialization of `T`, and execute outside of the Catalyst runtime. - */ + /** + * Methods on `TypedDataset[T]` that go through a full serialization and + * deserialization of `T`, and execute outside of the Catalyst runtime. 
+ */ object deserialized { - def mapGroups[U: TypedEncoder](f: (K1, Iterator[V]) => U): TypedDataset[U] = { + + def mapGroups[U: TypedEncoder]( + f: (K1, Iterator[V]) => U + ): TypedDataset[U] = { underlying.deserialized.mapGroups(AggregatingOps.tuple1(f)) } - def flatMapGroups[U: TypedEncoder](f: (K1, Iterator[V]) => TraversableOnce[U]): TypedDataset[U] = { + def flatMapGroups[U: TypedEncoder]( + f: (K1, Iterator[V]) => TraversableOnce[U] + ): TypedDataset[U] = { underlying.deserialized.flatMapGroups(AggregatingOps.tuple1(f)) } } - def pivot[P: CatalystPivotable](pivotColumn: TypedColumn[V, P]): PivotNotValues[V, TypedColumn[V,K1] :: HNil, P] = + def pivot[P: CatalystPivotable]( + pivotColumn: TypedColumn[V, P] + ): PivotNotValues[V, TypedColumn[V, K1] :: HNil, P] = PivotNotValues(self, g1 :: HNil, pivotColumn) } - class GroupedBy2Ops[K1, K2, V]( - self: TypedDataset[V], - g1: TypedColumn[V, K1], - g2: TypedColumn[V, K2] -) { + self: TypedDataset[V], + g1: TypedColumn[V, K1], + g2: TypedColumn[V, K2]) { private def underlying = new GroupedByManyOps(self, g1 :: g2 :: HNil) private implicit def eg1 = g1.uencoder private implicit def eg2 = g2.uencoder @@ -93,57 +139,88 @@ class GroupedBy2Ops[K1, K2, V]( underlying.agg(c1) } - def agg[U1, U2](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2]): TypedDataset[(K1, K2, U1, U2)] = { + def agg[U1, U2]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2] + ): TypedDataset[(K1, K2, U1, U2)] = { implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder underlying.agg(c1, c2) } - def agg[U1, U2, U3](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3]): TypedDataset[(K1, K2, U1, U2, U3)] = { - implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder + def agg[U1, U2, U3]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2], + c3: TypedAggregate[V, U3] + ): TypedDataset[(K1, K2, U1, U2, U3)] = { + implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; + implicit val e3 = c3.uencoder underlying.agg(c1, c2, c3) } - def agg[U1, U2, U3, U4](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3], c4: TypedAggregate[V, U4]): TypedDataset[(K1, K2, U1, U2, U3, U4)] = { - implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder - underlying.agg(c1 , c2 , c3 , c4) + def agg[U1, U2, U3, U4]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2], + c3: TypedAggregate[V, U3], + c4: TypedAggregate[V, U4] + ): TypedDataset[(K1, K2, U1, U2, U3, U4)] = { + implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; + implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder + underlying.agg(c1, c2, c3, c4) } - def agg[U1, U2, U3, U4, U5](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3], c4: TypedAggregate[V, U4], c5: TypedAggregate[V, U5]): TypedDataset[(K1, K2, U1, U2, U3, U4, U5)] = { - implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder; implicit val e5 = c5.uencoder + def agg[U1, U2, U3, U4, U5]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2], + c3: TypedAggregate[V, U3], + c4: TypedAggregate[V, U4], + c5: TypedAggregate[V, U5] + ): TypedDataset[(K1, K2, U1, U2, U3, U4, U5)] = { + implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; + implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder; + implicit val e5 = c5.uencoder underlying.agg(c1, c2, c3, c4, c5) } - - 
/** Methods on `TypedDataset[T]` that go through a full serialization and - * deserialization of `T`, and execute outside of the Catalyst runtime. - */ + /** + * Methods on `TypedDataset[T]` that go through a full serialization and + * deserialization of `T`, and execute outside of the Catalyst runtime. + */ object deserialized { - def mapGroups[U: TypedEncoder](f: ((K1, K2), Iterator[V]) => U): TypedDataset[U] = { + + def mapGroups[U: TypedEncoder]( + f: ((K1, K2), Iterator[V]) => U + ): TypedDataset[U] = { underlying.deserialized.mapGroups(f) } - def flatMapGroups[U: TypedEncoder](f: ((K1, K2), Iterator[V]) => TraversableOnce[U]): TypedDataset[U] = { + def flatMapGroups[U: TypedEncoder]( + f: ((K1, K2), Iterator[V]) => TraversableOnce[U] + ): TypedDataset[U] = { underlying.deserialized.flatMapGroups(f) } } - def pivot[P: CatalystPivotable](pivotColumn: TypedColumn[V, P]): - PivotNotValues[V, TypedColumn[V,K1] :: TypedColumn[V, K2] :: HNil, P] = - PivotNotValues(self, g1 :: g2 :: HNil, pivotColumn) + def pivot[P: CatalystPivotable]( + pivotColumn: TypedColumn[V, P] + ): PivotNotValues[V, TypedColumn[V, K1] :: TypedColumn[V, K2] :: HNil, P] = + PivotNotValues(self, g1 :: g2 :: HNil, pivotColumn) } -private[ops] abstract class AggregatingOps[T, TK <: HList, K <: HList, KT] - (self: TypedDataset[T], groupedBy: TK, groupingFunc: (Dataset[T], Seq[Column]) => RelationalGroupedDataset) - (implicit +private[ops] abstract class AggregatingOps[T, TK <: HList, K <: HList, KT]( + self: TypedDataset[T], + groupedBy: TK, + groupingFunc: (Dataset[T], Seq[Column]) => RelationalGroupedDataset + )(implicit i0: ColumnTypes.Aux[T, TK, K], i1: ToTraversable.Aux[TK, List, UntypedExpression[T]], - i2: Tupler.Aux[K, KT] - ) { - def aggregate[TC <: HList, Out1](columns: TC) - (implicit - i7: TypedEncoder[Out1], - i8: ToTraversable.Aux[TC, List, UntypedExpression[T]] - ): TypedDataset[Out1] = { + i2: Tupler.Aux[K, KT]) { + + def aggregate[TC <: HList, Out1]( + columns: TC + )(implicit + i7: TypedEncoder[Out1], + i8: ToTraversable.Aux[TC, List, UntypedExpression[T]] + ): TypedDataset[Out1] = { def expr(c: UntypedExpression[T]): Column = new Column(c.expr) val groupByExprs = groupedBy.toList[UntypedExpression[T]].map(expr) @@ -159,25 +236,32 @@ private[ops] abstract class AggregatingOps[T, TK <: HList, K <: HList, KT] TypedDataset.create[Out1](aggregated) } - /** Methods on `TypedDataset[T]` that go through a full serialization and - * deserialization of `T`, and execute outside of the Catalyst runtime. - */ + /** + * Methods on `TypedDataset[T]` that go through a full serialization and + * deserialization of `T`, and execute outside of the Catalyst runtime. 
+ */ object deserialized { + def mapGroups[U: TypedEncoder]( - f: (KT, Iterator[T]) => U - )(implicit e: TypedEncoder[KT]): TypedDataset[U] = { + f: (KT, Iterator[T]) => U + )(implicit + e: TypedEncoder[KT] + ): TypedDataset[U] = { val func = (key: KT, it: Iterator[T]) => Iterator(f(key, it)) flatMapGroups(func) } def flatMapGroups[U: TypedEncoder]( - f: (KT, Iterator[T]) => TraversableOnce[U] - )(implicit e: TypedEncoder[KT]): TypedDataset[U] = { + f: (KT, Iterator[T]) => TraversableOnce[U] + )(implicit + e: TypedEncoder[KT] + ): TypedDataset[U] = { implicit val tendcoder = self.encoder val cols = groupedBy.toList[UntypedExpression[T]] val logicalPlan = FramelessInternals.logicalPlan(self.dataset) - val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_)) + val withKeyColumns = + logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_)) val withKey = Project(withKeyColumns, logicalPlan) val executed = FramelessInternals.executePlan(self.dataset, withKey) val keyAttributes = executed.analyzed.output.takeRight(cols.size) @@ -188,7 +272,11 @@ private[ops] abstract class AggregatingOps[T, TK <: HList, K <: HList, KT] keyAttributes, dataAttributes, executed.analyzed - )(TypedExpressionEncoder[KT], TypedExpressionEncoder[T], TypedExpressionEncoder[U]) + )( + TypedExpressionEncoder[KT], + TypedExpressionEncoder[T], + TypedExpressionEncoder[U] + ) val groupedAndFlatMapped = FramelessInternals.mkDataset( self.dataset.sqlContext, @@ -201,66 +289,97 @@ private[ops] abstract class AggregatingOps[T, TK <: HList, K <: HList, KT] } private def retainGroupColumns: Boolean = { - self.dataset.sqlContext.getConf("spark.sql.retainGroupColumns", "true").toBoolean + self.dataset.sqlContext + .getConf("spark.sql.retainGroupColumns", "true") + .toBoolean } - def pivot[P: CatalystPivotable](pivotColumn: TypedColumn[T, P]): PivotNotValues[T, TK, P] = + def pivot[P: CatalystPivotable]( + pivotColumn: TypedColumn[T, P] + ): PivotNotValues[T, TK, P] = PivotNotValues(self, groupedBy, pivotColumn) } private[ops] object AggregatingOps { + /** Utility function to help Spark with serialization of closures */ - def tuple1[K1, V, U](f: (K1, Iterator[V]) => U): (Tuple1[K1], Iterator[V]) => U = { - (x: Tuple1[K1], it: Iterator[V]) => f(x._1, it) + def tuple1[K1, V, U]( + f: (K1, Iterator[V]) => U + ): (Tuple1[K1], Iterator[V]) => U = { (x: Tuple1[K1], it: Iterator[V]) => + f(x._1, it) } } -/** Represents a typed Pivot operation. - */ +/** + * Represents a typed Pivot operation. 
+ */ final case class Pivot[T, GroupedColumns <: HList, PivotType, Values <: HList]( - ds: TypedDataset[T], - groupedBy: GroupedColumns, - pivotedBy: TypedColumn[T, PivotType], - values: Values -) { + ds: TypedDataset[T], + groupedBy: GroupedColumns, + pivotedBy: TypedColumn[T, PivotType], + values: Values) { object agg extends ProductArgs { - def applyProduct[AggrColumns <: HList, AggrColumnTypes <: HList, GroupedColumnTypes <: HList, NumValues <: Nat, TypesForPivotedValues <: HList, TypesForPivotedValuesOpt <: HList, OutAsHList <: HList, Out] - (aggrColumns: AggrColumns) - (implicit + + def applyProduct[ + AggrColumns <: HList, + AggrColumnTypes <: HList, + GroupedColumnTypes <: HList, + NumValues <: Nat, + TypesForPivotedValues <: HList, + TypesForPivotedValuesOpt <: HList, + OutAsHList <: HList, + Out + ](aggrColumns: AggrColumns + )(implicit i0: AggregateTypes.Aux[T, AggrColumns, AggrColumnTypes], i1: ColumnTypes.Aux[T, GroupedColumns, GroupedColumnTypes], i2: Length.Aux[Values, NumValues], i3: Repeat.Aux[AggrColumnTypes, NumValues, TypesForPivotedValues], i4: Mapped.Aux[TypesForPivotedValues, Option, TypesForPivotedValuesOpt], - i5: Prepend.Aux[GroupedColumnTypes, TypesForPivotedValuesOpt, OutAsHList], + i5: Prepend.Aux[ + GroupedColumnTypes, + TypesForPivotedValuesOpt, + OutAsHList + ], i6: Tupler.Aux[OutAsHList, Out], i7: TypedEncoder[Out] ): TypedDataset[Out] = { - def mapAny[X](h: HList)(f: Any => X): List[X] = - h match { - case HNil => Nil - case x :: xs => f(x) :: mapAny(xs)(f) - } - - val aggCols: Seq[Column] = mapAny(aggrColumns)(x => new Column(x.asInstanceOf[TypedAggregate[_,_]].expr)) - val tmp = ds.dataset.toDF() - .groupBy(mapAny(groupedBy)(_.asInstanceOf[TypedColumn[_, _]].untyped): _*) - .pivot(pivotedBy.untyped.toString, mapAny(values)(identity)) - .agg(aggCols.head, aggCols.tail:_*) - .as[Out](TypedExpressionEncoder[Out]) - TypedDataset.create(tmp) - } + def mapAny[X](h: HList)(f: Any => X): List[X] = + h match { + case HNil => Nil + case x :: xs => f(x) :: mapAny(xs)(f) + } + + val aggCols: Seq[Column] = mapAny(aggrColumns)(x => + new Column(x.asInstanceOf[TypedAggregate[_, _]].expr) + ) + val tmp = ds.dataset + .toDF() + .groupBy( + mapAny(groupedBy)(_.asInstanceOf[TypedColumn[_, _]].untyped): _* + ) + .pivot(pivotedBy.untyped.toString, mapAny(values)(identity)) + .agg(aggCols.head, aggCols.tail: _*) + .as[Out](TypedExpressionEncoder[Out]) + TypedDataset.create(tmp) + } } } final case class PivotNotValues[T, GroupedColumns <: HList, PivotType]( - ds: TypedDataset[T], - groupedBy: GroupedColumns, - pivotedBy: TypedColumn[T, PivotType] -) extends ProductArgs { - - def onProduct[Values <: HList](values: Values)( - implicit validValues: ToList[Values, PivotType] // validValues: FilterNot.Aux[Values, PivotType, HNil] // did not work - ): Pivot[T, GroupedColumns, PivotType, Values] = Pivot(ds, groupedBy, pivotedBy, values) + ds: TypedDataset[T], + groupedBy: GroupedColumns, + pivotedBy: TypedColumn[T, PivotType]) + extends ProductArgs { + + def onProduct[Values <: HList]( + values: Values + )(implicit + validValues: ToList[ + Values, + PivotType + ] // validValues: FilterNot.Aux[Values, PivotType, HNil] // did not work + ): Pivot[T, GroupedColumns, PivotType, Values] = + Pivot(ds, groupedBy, pivotedBy, values) } diff --git a/dataset/src/main/scala/frameless/ops/RelationalGroupsOps.scala b/dataset/src/main/scala/frameless/ops/RelationalGroupsOps.scala index 569407762..bfffdb1fb 100644 --- a/dataset/src/main/scala/frameless/ops/RelationalGroupsOps.scala +++ 
b/dataset/src/main/scala/frameless/ops/RelationalGroupsOps.scala @@ -1,49 +1,74 @@ package frameless package ops -import org.apache.spark.sql.{Column, Dataset, RelationalGroupedDataset} -import shapeless.ops.hlist.{Mapped, Prepend, ToTraversable, Tupler} -import shapeless.{::, HList, HNil, ProductArgs} +import org.apache.spark.sql.{ Column, Dataset, RelationalGroupedDataset } +import shapeless.ops.hlist.{ Mapped, Prepend, ToTraversable, Tupler } +import shapeless.{ ::, HList, HNil, ProductArgs } /** - * @param groupingFunc functions used to group elements, can be cube or rollup - * @tparam T the original `TypedDataset's` type T - * @tparam TK all columns chosen for aggregation - * @tparam K individual columns' types as HList - * @tparam KT individual columns' types as Tuple - */ -private[ops] abstract class RelationalGroupsOps[T, TK <: HList, K <: HList, KT] - (self: TypedDataset[T], groupedBy: TK, groupingFunc: (Dataset[T], Seq[Column]) => RelationalGroupedDataset) - (implicit + * @param groupingFunc functions used to group elements, can be cube or rollup + * @tparam T the original `TypedDataset's` type T + * @tparam TK all columns chosen for aggregation + * @tparam K individual columns' types as HList + * @tparam KT individual columns' types as Tuple + */ +private[ops] abstract class RelationalGroupsOps[T, TK <: HList, K <: HList, KT]( + self: TypedDataset[T], + groupedBy: TK, + groupingFunc: (Dataset[T], Seq[Column]) => RelationalGroupedDataset + )(implicit i0: ColumnTypes.Aux[T, TK, K], i1: ToTraversable.Aux[TK, List, UntypedExpression[T]], - i2: Tupler.Aux[K, KT] - ) extends AggregatingOps(self, groupedBy, groupingFunc){ + i2: Tupler.Aux[K, KT]) + extends AggregatingOps(self, groupedBy, groupingFunc) { object agg extends ProductArgs { + /** - * @tparam TC resulting columns after aggregation function - * @tparam C individual columns' types as HList - * @tparam OptK columns' types mapped to Option - * @tparam Out0 OptK columns appended to C - * @tparam Out1 output type - */ - def applyProduct[TC <: HList, C <: HList, OptK <: HList, Out0 <: HList, Out1] - (columns: TC) - (implicit - i3: AggregateTypes.Aux[T, TC, C], // shares individual columns' types after agg function as HList - i4: Mapped.Aux[K, Option, OptK], // maps all original columns' types to Option - i5: Prepend.Aux[OptK, C, Out0], // concatenates Option columns with those resulting from applying agg function - i6: Tupler.Aux[Out0, Out1], // converts resulting HList into Tuple for output type - i7: TypedEncoder[Out1], // proof that there is `TypedEncoder` for the output type - i8: ToTraversable.Aux[TC, List, UntypedExpression[T]] // allows converting this HList to ordinary List - ): TypedDataset[Out1] = { + * @tparam TC resulting columns after aggregation function + * @tparam C individual columns' types as HList + * @tparam OptK columns' types mapped to Option + * @tparam Out0 OptK columns appended to C + * @tparam Out1 output type + */ + def applyProduct[ + TC <: HList, + C <: HList, + OptK <: HList, + Out0 <: HList, + Out1 + ](columns: TC + )(implicit + i3: AggregateTypes.Aux[ + T, + TC, + C + ], // shares individual columns' types after agg function as HList + i4: Mapped.Aux[ + K, + Option, + OptK + ], // maps all original columns' types to Option + i5: Prepend.Aux[OptK, C, Out0], // concatenates Option columns with those resulting from applying agg function + i6: Tupler.Aux[ + Out0, + Out1 + ], // converts resulting HList into Tuple for output type + i7: TypedEncoder[ + Out1 + ], // proof that there is `TypedEncoder` for 
the output type + i8: ToTraversable.Aux[TC, List, UntypedExpression[ + T + ]] // allows converting this HList to ordinary List + ): TypedDataset[Out1] = { aggregate[TC, Out1](columns) } } } -private[ops] abstract class RelationalGroups1Ops[K1, V](self: TypedDataset[V], g1: TypedColumn[V, K1]) { +private[ops] abstract class RelationalGroups1Ops[K1, V]( + self: TypedDataset[V], + g1: TypedColumn[V, K1]) { protected def underlying: RelationalGroupsOps[V, ::[TypedColumn[V, K1], HNil], ::[K1, HNil], Tuple1[K1]] private implicit def eg1 = g1.uencoder @@ -52,117 +77,203 @@ private[ops] abstract class RelationalGroups1Ops[K1, V](self: TypedDataset[V], g underlying.agg(c1) } - def agg[U1, U2](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2]): TypedDataset[(Option[K1], U1, U2)] = { + def agg[U1, U2]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2] + ): TypedDataset[(Option[K1], U1, U2)] = { implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder underlying.agg(c1, c2) } - def agg[U1, U2, U3](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3]): TypedDataset[(Option[K1], U1, U2, U3)] = { - implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder + def agg[U1, U2, U3]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2], + c3: TypedAggregate[V, U3] + ): TypedDataset[(Option[K1], U1, U2, U3)] = { + implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; + implicit val e3 = c3.uencoder underlying.agg(c1, c2, c3) } - def agg[U1, U2, U3, U4](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3], c4: TypedAggregate[V, U4]): TypedDataset[(Option[K1], U1, U2, U3, U4)] = { - implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder + def agg[U1, U2, U3, U4]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2], + c3: TypedAggregate[V, U3], + c4: TypedAggregate[V, U4] + ): TypedDataset[(Option[K1], U1, U2, U3, U4)] = { + implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; + implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder underlying.agg(c1, c2, c3, c4) } - def agg[U1, U2, U3, U4, U5](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3], c4: TypedAggregate[V, U4], c5: TypedAggregate[V, U5]): TypedDataset[(Option[K1], U1, U2, U3, U4, U5)] = { - implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder; implicit val e5 = c5.uencoder + def agg[U1, U2, U3, U4, U5]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2], + c3: TypedAggregate[V, U3], + c4: TypedAggregate[V, U4], + c5: TypedAggregate[V, U5] + ): TypedDataset[(Option[K1], U1, U2, U3, U4, U5)] = { + implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; + implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder; + implicit val e5 = c5.uencoder underlying.agg(c1, c2, c3, c4, c5) } - /** Methods on `TypedDataset[T]` that go through a full serialization and - * deserialization of `T`, and execute outside of the Catalyst runtime. - */ + /** + * Methods on `TypedDataset[T]` that go through a full serialization and + * deserialization of `T`, and execute outside of the Catalyst runtime. 
+ */ object deserialized { - def mapGroups[U: TypedEncoder](f: (K1, Iterator[V]) => U): TypedDataset[U] = { + + def mapGroups[U: TypedEncoder]( + f: (K1, Iterator[V]) => U + ): TypedDataset[U] = { underlying.deserialized.mapGroups(AggregatingOps.tuple1(f)) } - def flatMapGroups[U: TypedEncoder](f: (K1, Iterator[V]) => TraversableOnce[U]): TypedDataset[U] = { + def flatMapGroups[U: TypedEncoder]( + f: (K1, Iterator[V]) => TraversableOnce[U] + ): TypedDataset[U] = { underlying.deserialized.flatMapGroups(AggregatingOps.tuple1(f)) } } - def pivot[P: CatalystPivotable](pivotColumn: TypedColumn[V, P]): PivotNotValues[V, TypedColumn[V,K1] :: HNil, P] = + def pivot[P: CatalystPivotable]( + pivotColumn: TypedColumn[V, P] + ): PivotNotValues[V, TypedColumn[V, K1] :: HNil, P] = PivotNotValues(self, g1 :: HNil, pivotColumn) } -private[ops] abstract class RelationalGroups2Ops[K1, K2, V](self: TypedDataset[V], g1: TypedColumn[V, K1], g2: TypedColumn[V, K2]) { +private[ops] abstract class RelationalGroups2Ops[K1, K2, V]( + self: TypedDataset[V], + g1: TypedColumn[V, K1], + g2: TypedColumn[V, K2]) { protected def underlying: RelationalGroupsOps[V, ::[TypedColumn[V, K1], ::[TypedColumn[V, K2], HNil]], ::[K1, ::[K2, HNil]], (K1, K2)] private implicit def eg1 = g1.uencoder private implicit def eg2 = g2.uencoder - def agg[U1](c1: TypedAggregate[V, U1]): TypedDataset[(Option[K1], Option[K2], U1)] = { + def agg[U1]( + c1: TypedAggregate[V, U1] + ): TypedDataset[(Option[K1], Option[K2], U1)] = { implicit val e1 = c1.uencoder underlying.agg(c1) } - def agg[U1, U2](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2]): TypedDataset[(Option[K1], Option[K2], U1, U2)] = { + def agg[U1, U2]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2] + ): TypedDataset[(Option[K1], Option[K2], U1, U2)] = { implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder underlying.agg(c1, c2) } - def agg[U1, U2, U3](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3]): TypedDataset[(Option[K1], Option[K2], U1, U2, U3)] = { - implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder + def agg[U1, U2, U3]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2], + c3: TypedAggregate[V, U3] + ): TypedDataset[(Option[K1], Option[K2], U1, U2, U3)] = { + implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; + implicit val e3 = c3.uencoder underlying.agg(c1, c2, c3) } - def agg[U1, U2, U3, U4](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3], c4: TypedAggregate[V, U4]): TypedDataset[(Option[K1], Option[K2], U1, U2, U3, U4)] = { - implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder - underlying.agg(c1 , c2 , c3 , c4) + def agg[U1, U2, U3, U4]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2], + c3: TypedAggregate[V, U3], + c4: TypedAggregate[V, U4] + ): TypedDataset[(Option[K1], Option[K2], U1, U2, U3, U4)] = { + implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; + implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder + underlying.agg(c1, c2, c3, c4) } - def agg[U1, U2, U3, U4, U5](c1: TypedAggregate[V, U1], c2: TypedAggregate[V, U2], c3: TypedAggregate[V, U3], c4: TypedAggregate[V, U4], c5: TypedAggregate[V, U5]): TypedDataset[(Option[K1], Option[K2], U1, U2, U3, U4, U5)] = { - implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder; implicit val e5 = c5.uencoder + def 
agg[U1, U2, U3, U4, U5]( + c1: TypedAggregate[V, U1], + c2: TypedAggregate[V, U2], + c3: TypedAggregate[V, U3], + c4: TypedAggregate[V, U4], + c5: TypedAggregate[V, U5] + ): TypedDataset[(Option[K1], Option[K2], U1, U2, U3, U4, U5)] = { + implicit val e1 = c1.uencoder; implicit val e2 = c2.uencoder; + implicit val e3 = c3.uencoder; implicit val e4 = c4.uencoder; + implicit val e5 = c5.uencoder underlying.agg(c1, c2, c3, c4, c5) } - /** Methods on `TypedDataset[T]` that go through a full serialization and - * deserialization of `T`, and execute outside of the Catalyst runtime. - */ + /** + * Methods on `TypedDataset[T]` that go through a full serialization and + * deserialization of `T`, and execute outside of the Catalyst runtime. + */ object deserialized { - def mapGroups[U: TypedEncoder](f: ((K1, K2), Iterator[V]) => U): TypedDataset[U] = { + + def mapGroups[U: TypedEncoder]( + f: ((K1, K2), Iterator[V]) => U + ): TypedDataset[U] = { underlying.deserialized.mapGroups(f) } - def flatMapGroups[U: TypedEncoder](f: ((K1, K2), Iterator[V]) => TraversableOnce[U]): TypedDataset[U] = { + def flatMapGroups[U: TypedEncoder]( + f: ((K1, K2), Iterator[V]) => TraversableOnce[U] + ): TypedDataset[U] = { underlying.deserialized.flatMapGroups(f) } } - def pivot[P: CatalystPivotable](pivotColumn: TypedColumn[V, P]): - PivotNotValues[V, TypedColumn[V,K1] :: TypedColumn[V, K2] :: HNil, P] = + def pivot[P: CatalystPivotable]( + pivotColumn: TypedColumn[V, P] + ): PivotNotValues[V, TypedColumn[V, K1] :: TypedColumn[V, K2] :: HNil, P] = PivotNotValues(self, g1 :: g2 :: HNil, pivotColumn) } -class RollupManyOps[T, TK <: HList, K <: HList, KT](self: TypedDataset[T], groupedBy: TK) - (implicit +class RollupManyOps[T, TK <: HList, K <: HList, KT]( + self: TypedDataset[T], + groupedBy: TK + )(implicit i0: ColumnTypes.Aux[T, TK, K], i1: ToTraversable.Aux[TK, List, UntypedExpression[T]], - i2: Tupler.Aux[K, KT] - ) extends RelationalGroupsOps[T, TK, K, KT](self, groupedBy, (dataset, cols) => dataset.rollup(cols: _*)) + i2: Tupler.Aux[K, KT]) + extends RelationalGroupsOps[T, TK, K, KT]( + self, + groupedBy, + (dataset, cols) => dataset.rollup(cols: _*) + ) -class Rollup1Ops[K1, V](self: TypedDataset[V], g1: TypedColumn[V, K1]) extends RelationalGroups1Ops(self, g1) { +class Rollup1Ops[K1, V](self: TypedDataset[V], g1: TypedColumn[V, K1]) + extends RelationalGroups1Ops(self, g1) { override protected def underlying = new RollupManyOps(self, g1 :: HNil) } -class Rollup2Ops[K1, K2, V](self: TypedDataset[V], g1: TypedColumn[V, K1], g2: TypedColumn[V, K2]) extends RelationalGroups2Ops(self, g1, g2) { +class Rollup2Ops[K1, K2, V]( + self: TypedDataset[V], + g1: TypedColumn[V, K1], + g2: TypedColumn[V, K2]) + extends RelationalGroups2Ops(self, g1, g2) { override protected def underlying = new RollupManyOps(self, g1 :: g2 :: HNil) } -class CubeManyOps[T, TK <: HList, K <: HList, KT](self: TypedDataset[T], groupedBy: TK) - (implicit +class CubeManyOps[T, TK <: HList, K <: HList, KT]( + self: TypedDataset[T], + groupedBy: TK + )(implicit i0: ColumnTypes.Aux[T, TK, K], i1: ToTraversable.Aux[TK, List, UntypedExpression[T]], - i2: Tupler.Aux[K, KT] - ) extends RelationalGroupsOps[T, TK, K, KT](self, groupedBy, (dataset, cols) => dataset.cube(cols: _*)) + i2: Tupler.Aux[K, KT]) + extends RelationalGroupsOps[T, TK, K, KT]( + self, + groupedBy, + (dataset, cols) => dataset.cube(cols: _*) + ) -class Cube1Ops[K1, V](self: TypedDataset[V], g1: TypedColumn[V, K1]) extends RelationalGroups1Ops(self, g1) { +class Cube1Ops[K1, V](self: 
TypedDataset[V], g1: TypedColumn[V, K1]) + extends RelationalGroups1Ops(self, g1) { override protected def underlying = new CubeManyOps(self, g1 :: HNil) } -class Cube2Ops[K1, K2, V](self: TypedDataset[V], g1: TypedColumn[V, K1], g2: TypedColumn[V, K2]) extends RelationalGroups2Ops(self, g1, g2) { +class Cube2Ops[K1, K2, V]( + self: TypedDataset[V], + g1: TypedColumn[V, K1], + g2: TypedColumn[V, K2]) + extends RelationalGroups2Ops(self, g1, g2) { override protected def underlying = new CubeManyOps(self, g1 :: g2 :: HNil) } diff --git a/dataset/src/main/scala/frameless/ops/Repeat.scala b/dataset/src/main/scala/frameless/ops/Repeat.scala index bde855500..b9f4d7c94 100644 --- a/dataset/src/main/scala/frameless/ops/Repeat.scala +++ b/dataset/src/main/scala/frameless/ops/Repeat.scala @@ -1,33 +1,37 @@ package frameless package ops -import shapeless.{HList, Nat, Succ} +import shapeless.{ HList, Nat, Succ } import shapeless.ops.hlist.Prepend -/** Typeclass supporting repeating L-typed HLists N times. - * - * Repeat[Int :: String :: HNil, Nat._2].Out =:= - * Int :: String :: Int :: String :: HNil - * - * By Jeremy Smith. To be replaced by `shapeless.ops.hlists.Repeat` - * once (https://github.com/milessabin/shapeless/pull/730 is published. - */ +/** + * Typeclass supporting repeating L-typed HLists N times. + * + * Repeat[Int :: String :: HNil, Nat._2].Out =:= + * Int :: String :: Int :: String :: HNil + * + * By Jeremy Smith. To be replaced by `shapeless.ops.hlists.Repeat` + * once (https://github.com/milessabin/shapeless/pull/730 is published. + */ trait Repeat[L <: HList, N <: Nat] { type Out <: HList } object Repeat { - type Aux[L <: HList, N <: Nat, Out0 <: HList] = Repeat[L, N] { type Out = Out0 } + + type Aux[L <: HList, N <: Nat, Out0 <: HList] = Repeat[L, N] { + type Out = Out0 + } implicit def base[L <: HList]: Aux[L, Nat._1, L] = new Repeat[L, Nat._1] { type Out = L } - implicit def succ[L <: HList, Prev <: Nat, PrevOut <: HList, P <: HList] - (implicit + implicit def succ[L <: HList, Prev <: Nat, PrevOut <: HList, P <: HList]( + implicit i0: Aux[L, Prev, PrevOut], i1: Prepend.Aux[L, PrevOut, P] ): Aux[L, Succ[Prev], P] = new Repeat[L, Succ[Prev]] { - type Out = P - } + type Out = P + } } diff --git a/dataset/src/main/scala/frameless/ops/SmartProject.scala b/dataset/src/main/scala/frameless/ops/SmartProject.scala index ec3628efd..7242975f4 100644 --- a/dataset/src/main/scala/frameless/ops/SmartProject.scala +++ b/dataset/src/main/scala/frameless/ops/SmartProject.scala @@ -2,38 +2,47 @@ package frameless package ops import shapeless.ops.hlist.ToTraversable -import shapeless.ops.record.{Keys, SelectAll, Values} -import shapeless.{HList, LabelledGeneric} +import shapeless.ops.record.{ Keys, SelectAll, Values } +import shapeless.{ HList, LabelledGeneric } import scala.annotation.implicitNotFound @implicitNotFound(msg = "Cannot prove that ${T} can be projected to ${U}. Perhaps not all member names and types of ${U} are the same in ${T}?") -case class SmartProject[T: TypedEncoder, U: TypedEncoder](apply: TypedDataset[T] => TypedDataset[U]) +case class SmartProject[T: TypedEncoder, U: TypedEncoder]( + apply: TypedDataset[T] => TypedDataset[U]) object SmartProject { + /** - * Proofs that there is a type-safe projection from a type T to another type U. It requires that: - * (a) both T and U are Products for which a LabelledGeneric can be derived (e.g., case classes), - * (b) all members of U have a corresponding member in T that has both the same name and type. 
- * - * @param i0 the LabelledGeneric derived for T - * @param i1 the LabelledGeneric derived for U - * @param i2 the keys of U - * @param i3 selects all the values from T using the keys of U - * @param i4 selects all the values of LabeledGeneric[U] - * @param i5 proof that U and the projection of T have the same type - * @param i6 allows for traversing the keys of U - * @tparam T the original type T - * @tparam U the projected type U - * @tparam TRec shapeless' Record representation of T - * @tparam TProj the projection of T using the keys of U - * @tparam URec shapeless' Record representation of U - * @tparam UVals the values of U as an HList - * @tparam UKeys the keys of U as an HList - * @return a projection if it exists - */ - implicit def deriveProduct[T: TypedEncoder, U: TypedEncoder, TRec <: HList, TProj <: HList, URec <: HList, UVals <: HList, UKeys <: HList] - (implicit + * Proofs that there is a type-safe projection from a type T to another type U. It requires that: + * (a) both T and U are Products for which a LabelledGeneric can be derived (e.g., case classes), + * (b) all members of U have a corresponding member in T that has both the same name and type. + * + * @param i0 the LabelledGeneric derived for T + * @param i1 the LabelledGeneric derived for U + * @param i2 the keys of U + * @param i3 selects all the values from T using the keys of U + * @param i4 selects all the values of LabeledGeneric[U] + * @param i5 proof that U and the projection of T have the same type + * @param i6 allows for traversing the keys of U + * @tparam T the original type T + * @tparam U the projected type U + * @tparam TRec shapeless' Record representation of T + * @tparam TProj the projection of T using the keys of U + * @tparam URec shapeless' Record representation of U + * @tparam UVals the values of U as an HList + * @tparam UKeys the keys of U as an HList + * @return a projection if it exists + */ + implicit def deriveProduct[ + T: TypedEncoder, + U: TypedEncoder, + TRec <: HList, + TProj <: HList, + URec <: HList, + UVals <: HList, + UKeys <: HList + ](implicit i0: LabelledGeneric.Aux[T, TRec], i1: LabelledGeneric.Aux[U, URec], i2: Keys.Aux[URec, UKeys], @@ -41,8 +50,14 @@ object SmartProject { i4: Values.Aux[URec, UVals], i5: UVals =:= TProj, i6: ToTraversable.Aux[UKeys, Seq, Symbol] - ): SmartProject[T,U] = SmartProject[T, U]({ from => - val names = implicitly[Keys.Aux[URec, UKeys]].apply().to[Seq].map(_.name).map(from.dataset.col) - TypedDataset.create(from.dataset.toDF().select(names: _*).as[U](TypedExpressionEncoder[U])) - }) + ): SmartProject[T, U] = SmartProject[T, U]({ from => + val names = implicitly[Keys.Aux[URec, UKeys]] + .apply() + .to[Seq] + .map(_.name) + .map(from.dataset.col) + TypedDataset.create( + from.dataset.toDF().select(names: _*).as[U](TypedExpressionEncoder[U]) + ) + }) } diff --git a/dataset/src/main/scala/frameless/syntax/package.scala b/dataset/src/main/scala/frameless/syntax/package.scala index c6045981f..7050f93e8 100644 --- a/dataset/src/main/scala/frameless/syntax/package.scala +++ b/dataset/src/main/scala/frameless/syntax/package.scala @@ -1,5 +1,7 @@ package frameless package object syntax extends FramelessSyntax { - implicit val DefaultSparkDelay: SparkDelay[Job] = Job.framelessSparkDelayForJob + + implicit val DefaultSparkDelay: SparkDelay[Job] = + Job.framelessSparkDelayForJob } diff --git a/dataset/src/main/scala/org/apache/spark/sql/FramelessInternals.scala b/dataset/src/main/scala/org/apache/spark/sql/FramelessInternals.scala index 5459230d4..5cdb34155 
100644 --- a/dataset/src/main/scala/org/apache/spark/sql/FramelessInternals.scala +++ b/dataset/src/main/scala/org/apache/spark/sql/FramelessInternals.scala @@ -2,24 +2,34 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct} -import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{ Alias, CreateStruct } +import org.apache.spark.sql.catalyst.expressions.{ Expression, NamedExpression } import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{ LogicalPlan, Project } import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.types._ import org.apache.spark.sql.types.ObjectType import scala.reflect.ClassTag object FramelessInternals { - def objectTypeFor[A](implicit classTag: ClassTag[A]): ObjectType = ObjectType(classTag.runtimeClass) + + def objectTypeFor[A]( + implicit + classTag: ClassTag[A] + ): ObjectType = ObjectType(classTag.runtimeClass) def resolveExpr(ds: Dataset[_], colNames: Seq[String]): NamedExpression = { - ds.toDF().queryExecution.analyzed.resolve(colNames, ds.sparkSession.sessionState.analyzer.resolver).getOrElse { - throw new AnalysisException( - s"""Cannot resolve column name "$colNames" among (${ds.schema.fieldNames.mkString(", ")})""") - } + ds.toDF() + .queryExecution + .analyzed + .resolve(colNames, ds.sparkSession.sessionState.analyzer.resolver) + .getOrElse { + throw new AnalysisException( + s"""Cannot resolve column name "$colNames" among (${ds.schema.fieldNames + .mkString(", ")})""" + ) + } } def expr(column: Column): Expression = column.expr @@ -29,45 +39,72 @@ object FramelessInternals { def executePlan(ds: Dataset[_], plan: LogicalPlan): QueryExecution = ds.sparkSession.sessionState.executePlan(plan) - def joinPlan(ds: Dataset[_], plan: LogicalPlan, leftPlan: LogicalPlan, rightPlan: LogicalPlan): LogicalPlan = { + def joinPlan( + ds: Dataset[_], + plan: LogicalPlan, + leftPlan: LogicalPlan, + rightPlan: LogicalPlan + ): LogicalPlan = { val joined = executePlan(ds, plan) val leftOutput = joined.analyzed.output.take(leftPlan.output.length) val rightOutput = joined.analyzed.output.takeRight(rightPlan.output.length) - Project(List( - Alias(CreateStruct(leftOutput), "_1")(), - Alias(CreateStruct(rightOutput), "_2")() - ), joined.analyzed) + Project( + List( + Alias(CreateStruct(leftOutput), "_1")(), + Alias(CreateStruct(rightOutput), "_2")() + ), + joined.analyzed + ) } - def mkDataset[T](sqlContext: SQLContext, plan: LogicalPlan, encoder: Encoder[T]): Dataset[T] = + def mkDataset[T]( + sqlContext: SQLContext, + plan: LogicalPlan, + encoder: Encoder[T] + ): Dataset[T] = new Dataset(sqlContext, plan, encoder) def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = Dataset.ofRows(sparkSession, logicalPlan) // because org.apache.spark.sql.types.UserDefinedType is private[spark] - type UserDefinedType[A >: Null] = org.apache.spark.sql.types.UserDefinedType[A] + type UserDefinedType[A >: Null] = + org.apache.spark.sql.types.UserDefinedType[A] // below only tested in SelfJoinTests.colLeft and colRight are equivalent to col outside of joins // - via files (codegen) forces doGenCode eval. 
/** Expression to tag columns from the left hand side of join expression. */ - case class DisambiguateLeft[T](tagged: Expression) extends Expression with NonSQLExpression { + case class DisambiguateLeft[T](tagged: Expression) + extends Expression + with NonSQLExpression { def eval(input: InternalRow): Any = tagged.eval(input) def nullable: Boolean = false def children: Seq[Expression] = tagged :: Nil def dataType: DataType = tagged.dataType - protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = tagged.genCode(ctx) - protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(newChildren.head) + + protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + tagged.genCode(ctx) + + protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression] + ): Expression = copy(newChildren.head) } /** Expression to tag columns from the right hand side of join expression. */ - case class DisambiguateRight[T](tagged: Expression) extends Expression with NonSQLExpression { + case class DisambiguateRight[T](tagged: Expression) + extends Expression + with NonSQLExpression { def eval(input: InternalRow): Any = tagged.eval(input) def nullable: Boolean = false def children: Seq[Expression] = tagged :: Nil def dataType: DataType = tagged.dataType - protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = tagged.genCode(ctx) - protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(newChildren.head) + + protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + tagged.genCode(ctx) + + protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression] + ): Expression = copy(newChildren.head) } } diff --git a/dataset/src/main/spark-3.4+/frameless/MapGroups.scala b/dataset/src/main/spark-3.4+/frameless/MapGroups.scala index 6856acba4..25411420b 100644 --- a/dataset/src/main/spark-3.4+/frameless/MapGroups.scala +++ b/dataset/src/main/spark-3.4+/frameless/MapGroups.scala @@ -2,15 +2,19 @@ package frameless import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, MapGroups => SMapGroups} +import org.apache.spark.sql.catalyst.plans.logical.{ + LogicalPlan, + MapGroups => SMapGroups +} object MapGroups { + def apply[K: Encoder, T: Encoder, U: Encoder]( - func: (K, Iterator[T]) => TraversableOnce[U], - groupingAttributes: Seq[Attribute], - dataAttributes: Seq[Attribute], - child: LogicalPlan - ): LogicalPlan = + func: (K, Iterator[T]) => TraversableOnce[U], + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + child: LogicalPlan + ): LogicalPlan = SMapGroups( func, groupingAttributes, diff --git a/dataset/src/main/spark-3/frameless/MapGroups.scala b/dataset/src/main/spark-3/frameless/MapGroups.scala index 3fd27f333..67ec8b731 100644 --- a/dataset/src/main/spark-3/frameless/MapGroups.scala +++ b/dataset/src/main/spark-3/frameless/MapGroups.scala @@ -2,13 +2,17 @@ package frameless import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, MapGroups => SMapGroups} +import org.apache.spark.sql.catalyst.plans.logical.{ + LogicalPlan, + MapGroups => SMapGroups +} object MapGroups { + def apply[K: Encoder, T: Encoder, U: Encoder]( - func: (K, Iterator[T]) => TraversableOnce[U], - groupingAttributes: Seq[Attribute], - dataAttributes: Seq[Attribute], - 
child: LogicalPlan - ): LogicalPlan = SMapGroups(func, groupingAttributes, dataAttributes, child) + func: (K, Iterator[T]) => TraversableOnce[U], + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + child: LogicalPlan + ): LogicalPlan = SMapGroups(func, groupingAttributes, dataAttributes, child) } diff --git a/dataset/src/test/scala/frameless/AsTests.scala b/dataset/src/test/scala/frameless/AsTests.scala index c1091f9ca..5dd19a6ce 100644 --- a/dataset/src/test/scala/frameless/AsTests.scala +++ b/dataset/src/test/scala/frameless/AsTests.scala @@ -5,14 +5,15 @@ import org.scalacheck.Prop._ class AsTests extends TypedDatasetSuite { test("as[X2[A, B]]") { - def prop[A, B](data: Vector[(A, B)])( - implicit - eab: TypedEncoder[(A, B)], - ex2: TypedEncoder[X2[A, B]] - ): Prop = { + def prop[A, B]( + data: Vector[(A, B)] + )(implicit + eab: TypedEncoder[(A, B)], + ex2: TypedEncoder[X2[A, B]] + ): Prop = { val dataset = TypedDataset.create(data) - val dataset2 = dataset.as[X2[A,B]]().collect().run().toVector + val dataset2 = dataset.as[X2[A, B]]().collect().run().toVector val data2 = data.map { case (a, b) => X2(a, b) } dataset2 ?= data2 @@ -27,17 +28,16 @@ class AsTests extends TypedDatasetSuite { } test("as[X2[X2[A, B], C]") { - def prop[A, B, C](data: Vector[(A, B, C)])( - implicit - eab: TypedEncoder[((A, B), C)], - ex2: TypedEncoder[X2[X2[A, B], C]] - ): Prop = { - val data2 = data.map { - case (a, b, c) => ((a, b), c) - } + def prop[A, B, C]( + data: Vector[(A, B, C)] + )(implicit + eab: TypedEncoder[((A, B), C)], + ex2: TypedEncoder[X2[X2[A, B], C]] + ): Prop = { + val data2 = data.map { case (a, b, c) => ((a, b), c) } val dataset = TypedDataset.create(data2) - val dataset2 = dataset.as[X2[X2[A,B], C]]().collect().run().toVector + val dataset2 = dataset.as[X2[X2[A, B], C]]().collect().run().toVector val data3 = data2.map { case ((a, b), c) => X2(X2(a, b), c) } dataset2 ?= data3 @@ -47,7 +47,13 @@ class AsTests extends TypedDatasetSuite { check(forAll(prop[String, Int, String] _)) check(forAll(prop[String, String, Int] _)) check(forAll(prop[Long, Int, String] _)) - check(forAll(prop[Seq[Seq[Option[Seq[Long]]]], Seq[Int], Option[Seq[Option[Int]]]] _)) - check(forAll(prop[Seq[Option[Seq[String]]], Seq[Int], Seq[Option[String]]] _)) + check( + forAll( + prop[Seq[Seq[Option[Seq[Long]]]], Seq[Int], Option[Seq[Option[Int]]]] _ + ) + ) + check( + forAll(prop[Seq[Option[Seq[String]]], Seq[Int], Seq[Option[String]]] _) + ) } } diff --git a/dataset/src/test/scala/frameless/BitwiseTests.scala b/dataset/src/test/scala/frameless/BitwiseTests.scala index f58c906a2..0d1914de0 100644 --- a/dataset/src/test/scala/frameless/BitwiseTests.scala +++ b/dataset/src/test/scala/frameless/BitwiseTests.scala @@ -7,12 +7,12 @@ import org.scalatest.matchers.should.Matchers class BitwiseTests extends TypedDatasetSuite with Matchers { /** - * providing instances with implementations for bitwise operations since in the tests - * we need to check the results from frameless vs the results from normal scala operators - * for Numeric it is easy to test since scala comes with Numeric typeclass but there seems - * to be no equivalent typeclass for bitwise ops for Byte Short Int and Long types supported in Catalyst - */ - trait CatalystBitwise4Tests[A]{ + * providing instances with implementations for bitwise operations since in the tests + * we need to check the results from frameless vs the results from normal scala operators + * for Numeric it is easy to test since scala comes with Numeric typeclass but 
there seems + * to be no equivalent typeclass for bitwise ops for Byte Short Int and Long types supported in Catalyst + */ + trait CatalystBitwise4Tests[A] { def bitwiseAnd(a1: A, a2: A): A def bitwiseOr(a1: A, a2: A): A def bitwiseXor(a1: A, a2: A): A @@ -22,33 +22,44 @@ class BitwiseTests extends TypedDatasetSuite with Matchers { } object CatalystBitwise4Tests { - implicit val framelessbyteBitwise : CatalystBitwise4Tests[Byte] = new CatalystBitwise4Tests[Byte] { - def bitwiseOr(a1: Byte, a2: Byte) : Byte = (a1 | a2).toByte - def bitwiseAnd(a1: Byte, a2: Byte): Byte = (a1 & a2).toByte - def bitwiseXor(a1: Byte, a2: Byte): Byte = (a1 ^ a2).toByte - } - implicit val framelessshortBitwise : CatalystBitwise4Tests[Short] = new CatalystBitwise4Tests[Short] { - def bitwiseOr(a1: Short, a2: Short) : Short = (a1 | a2).toShort - def bitwiseAnd(a1: Short, a2: Short): Short = (a1 & a2).toShort - def bitwiseXor(a1: Short, a2: Short): Short = (a1 ^ a2).toShort - } - implicit val framelessintBitwise : CatalystBitwise4Tests[Int] = new CatalystBitwise4Tests[Int] { - def bitwiseOr(a1: Int, a2: Int) : Int = a1 | a2 - def bitwiseAnd(a1: Int, a2: Int): Int = a1 & a2 - def bitwiseXor(a1: Int, a2: Int): Int = a1 ^ a2 - } - implicit val framelesslongBitwise : CatalystBitwise4Tests[Long] = new CatalystBitwise4Tests[Long] { - def bitwiseOr(a1: Long, a2: Long) : Long = a1 | a2 - def bitwiseAnd(a1: Long, a2: Long): Long = a1 & a2 - def bitwiseXor(a1: Long, a2: Long): Long = a1 ^ a2 - } + + implicit val framelessbyteBitwise: CatalystBitwise4Tests[Byte] = + new CatalystBitwise4Tests[Byte] { + def bitwiseOr(a1: Byte, a2: Byte): Byte = (a1 | a2).toByte + def bitwiseAnd(a1: Byte, a2: Byte): Byte = (a1 & a2).toByte + def bitwiseXor(a1: Byte, a2: Byte): Byte = (a1 ^ a2).toByte + } + + implicit val framelessshortBitwise: CatalystBitwise4Tests[Short] = + new CatalystBitwise4Tests[Short] { + def bitwiseOr(a1: Short, a2: Short): Short = (a1 | a2).toShort + def bitwiseAnd(a1: Short, a2: Short): Short = (a1 & a2).toShort + def bitwiseXor(a1: Short, a2: Short): Short = (a1 ^ a2).toShort + } + + implicit val framelessintBitwise: CatalystBitwise4Tests[Int] = + new CatalystBitwise4Tests[Int] { + def bitwiseOr(a1: Int, a2: Int): Int = a1 | a2 + def bitwiseAnd(a1: Int, a2: Int): Int = a1 & a2 + def bitwiseXor(a1: Int, a2: Int): Int = a1 ^ a2 + } + + implicit val framelesslongBitwise: CatalystBitwise4Tests[Long] = + new CatalystBitwise4Tests[Long] { + def bitwiseOr(a1: Long, a2: Long): Long = a1 | a2 + def bitwiseAnd(a1: Long, a2: Long): Long = a1 & a2 + def bitwiseXor(a1: Long, a2: Long): Long = a1 ^ a2 + } } import CatalystBitwise4Tests._ test("bitwiseAND") { - def prop[A: TypedEncoder: CatalystBitwise](a: A, b: A)( - implicit catalystBitwise4Tests: CatalystBitwise4Tests[A] - ): Prop = { + def prop[A: TypedEncoder: CatalystBitwise]( + a: A, + b: A + )(implicit + catalystBitwise4Tests: CatalystBitwise4Tests[A] + ): Prop = { val df = TypedDataset.create(X2(a, b) :: Nil) val result = implicitly[CatalystBitwise4Tests[A]].bitwiseAnd(a, b) val resultSymbolic = implicitly[CatalystBitwise4Tests[A]].&(a, b) @@ -56,7 +67,9 @@ class BitwiseTests extends TypedDatasetSuite with Matchers { val gotSymbolic = df.select(df.col('a) & b).collect().run() val symbolicCol2Col = df.select(df.col('a) & df.col('b)).collect().run() val canCast = df.select(df.col('a).cast[Long] & 0L).collect().run() - canCast should contain theSameElementsAs Seq.fill[Long](gotSymbolic.size)(0L) + canCast should contain theSameElementsAs Seq.fill[Long](gotSymbolic.size)( + 0L + 
) result ?= resultSymbolic symbolicCol2Col ?= (result :: Nil) got ?= (result :: Nil) @@ -70,9 +83,12 @@ class BitwiseTests extends TypedDatasetSuite with Matchers { } test("bitwiseOR") { - def prop[A: TypedEncoder: CatalystBitwise](a: A, b: A)( - implicit catalystBitwise4Tests: CatalystBitwise4Tests[A] - ): Prop = { + def prop[A: TypedEncoder: CatalystBitwise]( + a: A, + b: A + )(implicit + catalystBitwise4Tests: CatalystBitwise4Tests[A] + ): Prop = { val df = TypedDataset.create(X2(a, b) :: Nil) val result = implicitly[CatalystBitwise4Tests[A]].bitwiseOr(a, b) val resultSymbolic = implicitly[CatalystBitwise4Tests[A]].|(a, b) @@ -80,7 +96,9 @@ class BitwiseTests extends TypedDatasetSuite with Matchers { val gotSymbolic = df.select(df.col('a) | b).collect().run() val symbolicCol2Col = df.select(df.col('a) | df.col('b)).collect().run() val canCast = df.select(df.col('a).cast[Long] | -1L).collect().run() - canCast should contain theSameElementsAs Seq.fill[Long](gotSymbolic.size)(-1L) + canCast should contain theSameElementsAs Seq.fill[Long](gotSymbolic.size)( + -1L + ) result ?= resultSymbolic symbolicCol2Col ?= (result :: Nil) got ?= (result :: Nil) @@ -94,9 +112,12 @@ class BitwiseTests extends TypedDatasetSuite with Matchers { } test("bitwiseXOR") { - def prop[A: TypedEncoder: CatalystBitwise](a: A, b: A)( - implicit catalystBitwise4Tests: CatalystBitwise4Tests[A] - ): Prop = { + def prop[A: TypedEncoder: CatalystBitwise]( + a: A, + b: A + )(implicit + catalystBitwise4Tests: CatalystBitwise4Tests[A] + ): Prop = { val df = TypedDataset.create(X2(a, b) :: Nil) val result = implicitly[CatalystBitwise4Tests[A]].bitwiseXor(a, b) val resultSymbolic = implicitly[CatalystBitwise4Tests[A]].^(a, b) @@ -104,7 +125,9 @@ class BitwiseTests extends TypedDatasetSuite with Matchers { val got = df.select(df.col('a) bitwiseXOR df.col('b)).collect().run() val gotSymbolic = df.select(df.col('a) ^ b).collect().run() val zeroes = df.select(df.col('a) ^ df.col('a)).collect().run() - zeroes should contain theSameElementsAs Seq.fill[Long](gotSymbolic.size)(0L) + zeroes should contain theSameElementsAs Seq.fill[Long](gotSymbolic.size)( + 0L + ) got ?= (result :: Nil) gotSymbolic ?= (resultSymbolic :: Nil) } diff --git a/dataset/src/test/scala/frameless/CastTests.scala b/dataset/src/test/scala/frameless/CastTests.scala index 5f79f8fa6..00328d16c 100644 --- a/dataset/src/test/scala/frameless/CastTests.scala +++ b/dataset/src/test/scala/frameless/CastTests.scala @@ -1,14 +1,16 @@ package frameless -import org.scalacheck.{Arbitrary, Gen, Prop} +import org.scalacheck.{ Arbitrary, Gen, Prop } import org.scalacheck.Prop._ class CastTests extends TypedDatasetSuite { - def prop[A: TypedEncoder, B: TypedEncoder](f: A => B)(a: A)( - implicit - cast: CatalystCast[A, B] - ): Prop = { + def prop[A: TypedEncoder, B: TypedEncoder]( + f: A => B + )(a: A + )(implicit + cast: CatalystCast[A, B] + ): Prop = { val df = TypedDataset.create(X1(a) :: Nil) val got = df.select(df.col('a).cast[B]).collect().run() @@ -100,9 +102,11 @@ class CastTests extends TypedDatasetSuite { check(prop[Short, Boolean](_ != 0) _) // booleanToNumeric - check(prop[Boolean, BigDecimal](x => if (x) BigDecimal(1) else BigDecimal(0)) _) + check( + prop[Boolean, BigDecimal](x => if (x) BigDecimal(1) else BigDecimal(0)) _ + ) check(prop[Boolean, Byte](x => if (x) 1 else 0) _) - check(prop[Boolean, Double](x => if (x) 1.0f else 0.0f) _) + check(prop[Boolean, Double](x => if (x) 1.0F else 0.0F) _) check(prop[Boolean, Int](x => if (x) 1 else 0) _) check(prop[Boolean, 
Long](x => if (x) 1L else 0L) _) check(prop[Boolean, Short](x => if (x) 1 else 0) _) diff --git a/dataset/src/test/scala/frameless/ColTests.scala b/dataset/src/test/scala/frameless/ColTests.scala index ad62aa068..d71174ce7 100644 --- a/dataset/src/test/scala/frameless/ColTests.scala +++ b/dataset/src/test/scala/frameless/ColTests.scala @@ -16,7 +16,10 @@ class ColTests extends TypedDatasetSuite { x4.col[Int]('a) t4.col[Int]('_1) - illTyped("x4.col[String]('a)", "No column .* of type String in frameless.X4.*") + illTyped( + "x4.col[String]('a)", + "No column .* of type String in frameless.X4.*" + ) x4.col('b) t4.col('_2) diff --git a/dataset/src/test/scala/frameless/CollectTests.scala b/dataset/src/test/scala/frameless/CollectTests.scala index 0ff1e6956..56f661961 100644 --- a/dataset/src/test/scala/frameless/CollectTests.scala +++ b/dataset/src/test/scala/frameless/CollectTests.scala @@ -85,10 +85,18 @@ class CollectTests extends TypedDatasetSuite { object CollectTests { import frameless.syntax._ - def prop[A: TypedEncoder : ClassTag](data: Vector[A])(implicit c: SparkSession): Prop = + def prop[A: TypedEncoder: ClassTag]( + data: Vector[A] + )(implicit + c: SparkSession + ): Prop = TypedDataset.create(data).collect().run().toVector ?= data - def propArray[A: TypedEncoder : ClassTag](data: Vector[X1[Array[A]]])(implicit c: SparkSession): Prop = + def propArray[A: TypedEncoder: ClassTag]( + data: Vector[X1[Array[A]]] + )(implicit + c: SparkSession + ): Prop = Prop(TypedDataset.create(data).collect().run().toVector.zip(data).forall { case (X1(l), X1(r)) => l.sameElements(r) }) diff --git a/dataset/src/test/scala/frameless/ColumnViaLambdaTests.scala b/dataset/src/test/scala/frameless/ColumnViaLambdaTests.scala index 0a9c532a6..ec3f66ac7 100644 --- a/dataset/src/test/scala/frameless/ColumnViaLambdaTests.scala +++ b/dataset/src/test/scala/frameless/ColumnViaLambdaTests.scala @@ -11,9 +11,12 @@ case class MyClass4(h: Boolean) final class ColumnViaLambdaTests extends TypedDatasetSuite with Matchers { def ds = { - TypedDataset.create(Seq( - MyClass1(1, "2", MyClass2(3L, MyClass3(7.0D)), Some(MyClass4(true))), - MyClass1(4, "5", MyClass2(6L, MyClass3(8.0D)), None))) + TypedDataset.create( + Seq( + MyClass1(1, "2", MyClass2(3L, MyClass3(7.0D)), Some(MyClass4(true))), + MyClass1(4, "5", MyClass2(6L, MyClass3(8.0D)), None) + ) + ) } test("col(_.a)") { diff --git a/dataset/src/test/scala/frameless/CreateTests.scala b/dataset/src/test/scala/frameless/CreateTests.scala index 4d9b5547d..aebb30b7b 100644 --- a/dataset/src/test/scala/frameless/CreateTests.scala +++ b/dataset/src/test/scala/frameless/CreateTests.scala @@ -1,6 +1,6 @@ package frameless -import org.scalacheck.{Arbitrary, Prop} +import org.scalacheck.{ Arbitrary, Prop } import org.scalacheck.Prop._ import scala.reflect.ClassTag @@ -13,29 +13,40 @@ class CreateTests extends TypedDatasetSuite with Matchers { test("creation using X4 derived DataFrames") { def prop[ - A: TypedEncoder, - B: TypedEncoder, - C: TypedEncoder, - D: TypedEncoder](data: Vector[X4[A, B, C, D]]): Prop = { + A: TypedEncoder, + B: TypedEncoder, + C: TypedEncoder, + D: TypedEncoder + ](data: Vector[X4[A, B, C, D]] + ): Prop = { val ds = TypedDataset.create(data) - TypedDataset.createUnsafe[X4[A, B, C, D]](ds.toDF()).collect().run() ?= data + TypedDataset + .createUnsafe[X4[A, B, C, D]](ds.toDF()) + .collect() + .run() ?= data } check(forAll(prop[Int, Char, X2[Option[Country], Country], Int] _)) check(forAll(prop[X2[Int, Int], Int, Boolean, Vector[Food]] _)) 
check(forAll(prop[String, Food, X3[Food, Country, Boolean], Int] _)) check(forAll(prop[String, Food, X3U[Food, Country, Boolean], Int] _)) - check(forAll(prop[ - Option[Vector[Food]], - Vector[Vector[X2[Vector[(Person, X1[Char])], Country]]], - X3[Food, Country, String], - Vector[(Food, Country)]] _)) + check( + forAll( + prop[Option[Vector[Food]], Vector[ + Vector[X2[Vector[(Person, X1[Char])], Country]] + ], X3[Food, Country, String], Vector[(Food, Country)]] _ + ) + ) } test("array fields") { def prop[T: Arbitrary: TypedEncoder: ClassTag] = forAll { - (d1: Array[T], d2: Array[Option[T]], d3: Array[X1[T]], d4: Array[X1[Option[T]]], - d5: X1[Array[T]]) => + (d1: Array[T], + d2: Array[Option[T]], + d3: Array[X1[T]], + d4: Array[X1[Option[T]]], + d5: X1[Array[T]] + ) => TypedDataset.create(Seq(d1)).collect().run().head.sameElements(d1) && TypedDataset.create(Seq(d2)).collect().run().head.sameElements(d2) && TypedDataset.create(Seq(d3)).collect().run().head.sameElements(d3) && @@ -55,13 +66,17 @@ class CreateTests extends TypedDatasetSuite with Matchers { test("vector fields") { def prop[T: Arbitrary: TypedEncoder] = forAll { - (d1: Vector[T], d2: Vector[Option[T]], d3: Vector[X1[T]], d4: Vector[X1[Option[T]]], - d5: X1[Vector[T]]) => - (TypedDataset.create(Seq(d1)).collect().run().head ?= d1) && - (TypedDataset.create(Seq(d2)).collect().run().head ?= d2) && - (TypedDataset.create(Seq(d3)).collect().run().head ?= d3) && - (TypedDataset.create(Seq(d4)).collect().run().head ?= d4) && - (TypedDataset.create(Seq(d5)).collect().run().head ?= d5) + (d1: Vector[T], + d2: Vector[Option[T]], + d3: Vector[X1[T]], + d4: Vector[X1[Option[T]]], + d5: X1[Vector[T]] + ) => + (TypedDataset.create(Seq(d1)).collect().run().head ?= d1) && + (TypedDataset.create(Seq(d2)).collect().run().head ?= d2) && + (TypedDataset.create(Seq(d3)).collect().run().head ?= d3) && + (TypedDataset.create(Seq(d4)).collect().run().head ?= d4) && + (TypedDataset.create(Seq(d5)).collect().run().head ?= d5) } check(prop[Boolean]) @@ -77,9 +92,13 @@ class CreateTests extends TypedDatasetSuite with Matchers { test("list fields") { def prop[T: Arbitrary: TypedEncoder] = forAll { - (d1: List[T], d2: List[Option[T]], d3: List[X1[T]], d4: List[X1[Option[T]]], - d5: X1[List[T]]) => - (TypedDataset.create(Seq(d1)).collect().run().head ?= d1) && + (d1: List[T], + d2: List[Option[T]], + d3: List[X1[T]], + d4: List[X1[Option[T]]], + d5: X1[List[T]] + ) => + (TypedDataset.create(Seq(d1)).collect().run().head ?= d1) && (TypedDataset.create(Seq(d2)).collect().run().head ?= d2) && (TypedDataset.create(Seq(d3)).collect().run().head ?= d3) && (TypedDataset.create(Seq(d4)).collect().run().head ?= d4) && @@ -98,16 +117,23 @@ class CreateTests extends TypedDatasetSuite with Matchers { } test("Map fields (scala.Predef.Map / scala.collection.immutable.Map)") { - def prop[A: Arbitrary: NotCatalystNullable: TypedEncoder, B: Arbitrary: NotCatalystNullable: TypedEncoder] = forAll { - (d1: Map[A, B], d2: Map[B, A], d3: Map[A, Option[B]], - d4: Map[A, X1[B]], d5: Map[X1[A], B], d6: Map[X1[A], X1[B]]) => - - (TypedDataset.create(Seq(d1)).collect().run().head ?= d1) && - (TypedDataset.create(Seq(d2)).collect().run().head ?= d2) && - (TypedDataset.create(Seq(d3)).collect().run().head ?= d3) && - (TypedDataset.create(Seq(d4)).collect().run().head ?= d4) && - (TypedDataset.create(Seq(d5)).collect().run().head ?= d5) && - (TypedDataset.create(Seq(d6)).collect().run().head ?= d6) + def prop[ + A: Arbitrary: NotCatalystNullable: TypedEncoder, + B: Arbitrary: 
NotCatalystNullable: TypedEncoder + ] = forAll { + (d1: Map[A, B], + d2: Map[B, A], + d3: Map[A, Option[B]], + d4: Map[A, X1[B]], + d5: Map[X1[A], B], + d6: Map[X1[A], X1[B]] + ) => + (TypedDataset.create(Seq(d1)).collect().run().head ?= d1) && + (TypedDataset.create(Seq(d2)).collect().run().head ?= d2) && + (TypedDataset.create(Seq(d3)).collect().run().head ?= d3) && + (TypedDataset.create(Seq(d4)).collect().run().head ?= d4) && + (TypedDataset.create(Seq(d5)).collect().run().head ?= d5) && + (TypedDataset.create(Seq(d6)).collect().run().head ?= d6) } check(prop[String, String]) @@ -123,14 +149,17 @@ class CreateTests extends TypedDatasetSuite with Matchers { test("maps with Option keys should not resolve the TypedEncoder") { val data: Seq[Map[Option[Int], Int]] = Seq(Map(Some(5) -> 5)) - illTyped("TypedDataset.create(data)", ".*could not find implicit value for parameter encoder.*") + illTyped( + "TypedDataset.create(data)", + ".*could not find implicit value for parameter encoder.*" + ) } test("not aligned columns should throw an exception") { - val v = Vector(X2(1,2)) + val v = Vector(X2(1, 2)) val df = TypedDataset.create(v).dataset.toDF() - a [IllegalStateException] should be thrownBy { + a[IllegalStateException] should be thrownBy { TypedDataset.createUnsafe[X1[Int]](df).show().run() } } @@ -139,13 +168,18 @@ class CreateTests extends TypedDatasetSuite with Matchers { // e.g. when loading data from partitioned dataset // the partition columns get appended to the end of the underlying relation def prop[A: Arbitrary: TypedEncoder, B: Arbitrary: TypedEncoder] = forAll { - (a1: A, b1: B) => { - val ds = TypedDataset.create( - Vector((b1, a1)) - ).dataset.toDF("b", "a").as[X2[A, B]](TypedExpressionEncoder[X2[A, B]]) - TypedDataset.create(ds).collect().run().head ?= X2(a1, b1) - - } + (a1: A, b1: B) => + { + val ds = TypedDataset + .create( + Vector((b1, a1)) + ) + .dataset + .toDF("b", "a") + .as[X2[A, B]](TypedExpressionEncoder[X2[A, B]]) + TypedDataset.create(ds).collect().run().head ?= X2(a1, b1) + + } } check(prop[X1[Double], X1[X1[SQLDate]]]) check(prop[String, Int]) diff --git a/dataset/src/test/scala/frameless/DropTest.scala b/dataset/src/test/scala/frameless/DropTest.scala index 3e5a0d739..affd0d8d3 100644 --- a/dataset/src/test/scala/frameless/DropTest.scala +++ b/dataset/src/test/scala/frameless/DropTest.scala @@ -8,28 +8,36 @@ class DropTest extends TypedDatasetSuite { import DropTest._ test("fail to compile on missing value") { - val f: TypedDataset[X] = TypedDataset.create(X(1, 1, false) :: X(1, 1, false) :: X(1, 10, false) :: Nil) + val f: TypedDataset[X] = TypedDataset.create( + X(1, 1, false) :: X(1, 1, false) :: X(1, 10, false) :: Nil + ) illTyped { """val fNew: TypedDataset[XMissing] = f.drop[XMissing]('j)""" } } test("fail to compile on different column name") { - val f: TypedDataset[X] = TypedDataset.create(X(1, 1, false) :: X(1, 1, false) :: X(1, 10, false) :: Nil) + val f: TypedDataset[X] = TypedDataset.create( + X(1, 1, false) :: X(1, 1, false) :: X(1, 10, false) :: Nil + ) illTyped { """val fNew: TypedDataset[XDifferentColumnName] = f.drop[XDifferentColumnName]('j)""" } } test("fail to compile on added column name") { - val f: TypedDataset[X] = TypedDataset.create(X(1, 1, false) :: X(1, 1, false) :: X(1, 10, false) :: Nil) + val f: TypedDataset[X] = TypedDataset.create( + X(1, 1, false) :: X(1, 1, false) :: X(1, 10, false) :: Nil + ) illTyped { """val fNew: TypedDataset[XAdded] = f.drop[XAdded]('j)""" } } test("remove column in the middle") { - val f: 
TypedDataset[X] = TypedDataset.create(X(1, 1, false) :: X(1, 1, false) :: X(1, 10, false) :: Nil) + val f: TypedDataset[X] = TypedDataset.create( + X(1, 1, false) :: X(1, 1, false) :: X(1, 10, false) :: Nil + ) val fNew: TypedDataset[XGood] = f.drop[XGood] fNew.collect().run().foreach(xg => assert(xg === XGood(1, false))) diff --git a/dataset/src/test/scala/frameless/DropTupledTest.scala b/dataset/src/test/scala/frameless/DropTupledTest.scala index ff0158b91..d23b8a640 100644 --- a/dataset/src/test/scala/frameless/DropTupledTest.scala +++ b/dataset/src/test/scala/frameless/DropTupledTest.scala @@ -7,9 +7,9 @@ class DropTupledTest extends TypedDatasetSuite { test("drop five columns") { def prop[A: TypedEncoder](value: A): Prop = { val d5 = TypedDataset.create(X5(value, value, value, value, value) :: Nil) - val d4 = d5.dropTupled('a) //drops first column - val d3 = d4.dropTupled('_4) //drops last column - val d2 = d3.dropTupled('_2) //drops middle column + val d4 = d5.dropTupled('a) // drops first column + val d3 = d4.dropTupled('_4) // drops last column + val d2 = d3.dropTupled('_2) // drops middle column val d1 = d2.dropTupled('_2) Tuple1(value) ?= d1.collect().run().head diff --git a/dataset/src/test/scala/frameless/ExplodeTests.scala b/dataset/src/test/scala/frameless/ExplodeTests.scala index 3078ceb12..090f95522 100644 --- a/dataset/src/test/scala/frameless/ExplodeTests.scala +++ b/dataset/src/test/scala/frameless/ExplodeTests.scala @@ -1,7 +1,7 @@ package frameless import frameless.functions.CatalystExplodableCollection -import org.scalacheck.{Arbitrary, Prop} +import org.scalacheck.{ Arbitrary, Prop } import org.scalacheck.Prop.forAll import org.scalacheck.Prop._ @@ -9,12 +9,19 @@ import scala.reflect.ClassTag class ExplodeTests extends TypedDatasetSuite { test("simple explode test") { - val ds = TypedDataset.create(Seq((1,Array(1,2)))) - ds.explode('_2): TypedDataset[(Int,Int)] + val ds = TypedDataset.create(Seq((1, Array(1, 2)))) + ds.explode('_2): TypedDataset[(Int, Int)] } test("explode on vectors/list/seq") { - def prop[F[X] <: Traversable[X] : CatalystExplodableCollection, A: TypedEncoder](xs: List[X1[F[A]]])(implicit arb: Arbitrary[F[A]], enc: TypedEncoder[F[A]]): Prop = { + def prop[ + F[X] <: Traversable[X]: CatalystExplodableCollection, + A: TypedEncoder + ](xs: List[X1[F[A]]] + )(implicit + arb: Arbitrary[F[A]], + enc: TypedEncoder[F[A]] + ): Prop = { val tds = TypedDataset.create(xs) val framelessResults = tds.explode('a).collect().run().toVector @@ -49,11 +56,14 @@ class ExplodeTests extends TypedDatasetSuite { } test("explode on maps") { - def prop[A: TypedEncoder: ClassTag, B: TypedEncoder: ClassTag](xs: List[X1[Map[A, B]]]): Prop = { + def prop[A: TypedEncoder: ClassTag, B: TypedEncoder: ClassTag]( + xs: List[X1[Map[A, B]]] + ): Prop = { val tds = TypedDataset.create(xs) val framelessResults = tds.explodeMap('a).collect().run().toVector - val scalaResults = xs.flatMap(_.a.toList).map(t => Tuple1(Tuple2(t._1, t._2))).toVector + val scalaResults = + xs.flatMap(_.a.toList).map(t => Tuple1(Tuple2(t._1, t._2))).toVector framelessResults ?= scalaResults } @@ -64,11 +74,18 @@ class ExplodeTests extends TypedDatasetSuite { } test("explode on maps preserving other columns") { - def prop[K: TypedEncoder: ClassTag, A: TypedEncoder: ClassTag, B: TypedEncoder: ClassTag](xs: List[X2[K, Map[A, B]]]): Prop = { + def prop[ + K: TypedEncoder: ClassTag, + A: TypedEncoder: ClassTag, + B: TypedEncoder: ClassTag + ](xs: List[X2[K, Map[A, B]]] + ): Prop = { val tds = TypedDataset.create(xs) 
val framelessResults = tds.explodeMap('b).collect().run().toVector - val scalaResults = xs.flatMap { x2 => x2.b.toList.map((x2.a, _)) }.toVector + val scalaResults = xs.flatMap { x2 => + x2.b.toList.map((x2.a, _)) + }.toVector framelessResults ?= scalaResults } @@ -79,11 +96,19 @@ class ExplodeTests extends TypedDatasetSuite { } test("explode on maps making sure no key / value naming collision happens") { - def prop[K: TypedEncoder: ClassTag, V: TypedEncoder: ClassTag, A: TypedEncoder: ClassTag, B: TypedEncoder: ClassTag](xs: List[X3KV[K, V, Map[A, B]]]): Prop = { + def prop[ + K: TypedEncoder: ClassTag, + V: TypedEncoder: ClassTag, + A: TypedEncoder: ClassTag, + B: TypedEncoder: ClassTag + ](xs: List[X3KV[K, V, Map[A, B]]] + ): Prop = { val tds = TypedDataset.create(xs) val framelessResults = tds.explodeMap('c).collect().run().toVector - val scalaResults = xs.flatMap { x3 => x3.c.toList.map((x3.key, x3.value, _)) }.toVector + val scalaResults = xs.flatMap { x3 => + x3.c.toList.map((x3.key, x3.value, _)) + }.toVector framelessResults ?= scalaResults } diff --git a/dataset/src/test/scala/frameless/FilterTests.scala b/dataset/src/test/scala/frameless/FilterTests.scala index 56d5d2ec5..c93660ca9 100644 --- a/dataset/src/test/scala/frameless/FilterTests.scala +++ b/dataset/src/test/scala/frameless/FilterTests.scala @@ -7,7 +7,12 @@ import org.scalacheck.Prop._ final class FilterTests extends TypedDatasetSuite with Matchers { test("filter('a == lit(b))") { - def prop[A: TypedEncoder](elem: A, data: Vector[X1[A]])(implicit ex1: TypedEncoder[X1[A]]): Prop = { + def prop[A: TypedEncoder]( + elem: A, + data: Vector[X1[A]] + )(implicit + ex1: TypedEncoder[X1[A]] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col('a) @@ -22,7 +27,12 @@ final class FilterTests extends TypedDatasetSuite with Matchers { } test("filter('a =!= lit(b))") { - def prop[A: TypedEncoder](elem: A, data: Vector[X1[A]])(implicit ex1: TypedEncoder[X1[A]]): Prop = { + def prop[A: TypedEncoder]( + elem: A, + data: Vector[X1[A]] + )(implicit + ex1: TypedEncoder[X1[A]] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col('a) @@ -61,13 +71,13 @@ final class FilterTests extends TypedDatasetSuite with Matchers { } test("filter('a =!= 'b") { - def prop[A: TypedEncoder](elem: A, data: Vector[X2[A,A]]): Prop = { + def prop[A: TypedEncoder](elem: A, data: Vector[X2[A, A]]): Prop = { val dataset = TypedDataset.create(data) val cA = dataset.col('a) val cB = dataset.col('b) val dataset2 = dataset.filter(cA =!= cB).collect().run().toVector - val data2 = data.filter(x => x.a != x.b ) + val data2 = data.filter(x => x.a != x.b) (dataset2 ?= data2).&&(dataset.filter(cA =!= cA).count().run() ?= 0) } @@ -82,7 +92,8 @@ final class FilterTests extends TypedDatasetSuite with Matchers { test("filter with arithmetic expressions: addition") { check(forAll { (data: Vector[X1[Int]]) => val ds = TypedDataset.create(data) - val res = ds.filter((ds('a) + 1) === (ds('a) + 1)).collect().run().toVector + val res = + ds.filter((ds('a) + 1) === (ds('a) + 1)).collect().run().toVector res ?= data }) } @@ -99,21 +110,41 @@ final class FilterTests extends TypedDatasetSuite with Matchers { val t = X1(1) :: X1(2) :: X1(3) :: Nil val tds: TypedDataset[X1[Int]] = TypedDataset.create(t) - assert(tds.filter(tds('a) * 2 === 2).collect().run().toVector === Vector(X1(1))) - assert(tds.filter(tds('a) * 3 === 3).collect().run().toVector === Vector(X1(1))) + assert( + tds.filter(tds('a) * 2 === 2).collect().run().toVector === Vector(X1(1)) 
+ ) + assert( + tds.filter(tds('a) * 3 === 3).collect().run().toVector === Vector(X1(1)) + ) } test("Option equality/inequality for columns") { - def prop[A <: Option[_] : TypedEncoder](a: A, b: A): Prop = { + def prop[A <: Option[_]: TypedEncoder](a: A, b: A): Prop = { val data = X2(a, b) :: X2(a, a) :: Nil val dataset = TypedDataset.create(data) val A = dataset.col('a) val B = dataset.col('b) - (data.filter(x => x.a == x.b).toSet ?= dataset.filter(A === B).collect().run().toSet). - &&(data.filter(x => x.a != x.b).toSet ?= dataset.filter(A =!= B).collect().run().toSet). - &&(data.filter(x => x.a == None).toSet ?= dataset.filter(A.isNone).collect().run().toSet). - &&(data.filter(x => x.a == None).toSet ?= dataset.filter(A.isNotNone === false).collect().run().toSet) + (data + .filter(x => x.a == x.b) + .toSet ?= dataset.filter(A === B).collect().run().toSet) + .&&( + data + .filter(x => x.a != x.b) + .toSet ?= dataset.filter(A =!= B).collect().run().toSet + ) + .&&( + data + .filter(x => x.a == None) + .toSet ?= dataset.filter(A.isNone).collect().run().toSet + ) + .&&( + data.filter(x => x.a == None).toSet ?= dataset + .filter(A.isNotNone === false) + .collect() + .run() + .toSet + ) } check(forAll(prop[Option[Int]] _)) @@ -126,15 +157,31 @@ final class FilterTests extends TypedDatasetSuite with Matchers { } test("Option equality/inequality for lit") { - def prop[A <: Option[_] : TypedEncoder](a: A, b: A, cLit: A): Prop = { + def prop[A <: Option[_]: TypedEncoder](a: A, b: A, cLit: A): Prop = { val data = X2(a, b) :: X2(a, cLit) :: Nil val dataset = TypedDataset.create(data) val colA = dataset.col('a) - (data.filter(x => x.a == cLit).toSet ?= dataset.filter(colA === cLit).collect().run().toSet). - &&(data.filter(x => x.a != cLit).toSet ?= dataset.filter(colA =!= cLit).collect().run().toSet). - &&(data.filter(x => x.a == None).toSet ?= dataset.filter(colA.isNone).collect().run().toSet). 
- &&(data.filter(x => x.a == None).toSet ?= dataset.filter(colA.isNotNone === false).collect().run().toSet) + (data + .filter(x => x.a == cLit) + .toSet ?= dataset.filter(colA === cLit).collect().run().toSet) + .&&( + data + .filter(x => x.a != cLit) + .toSet ?= dataset.filter(colA =!= cLit).collect().run().toSet + ) + .&&( + data + .filter(x => x.a == None) + .toSet ?= dataset.filter(colA.isNone).collect().run().toSet + ) + .&&( + data.filter(x => x.a == None).toSet ?= dataset + .filter(colA.isNotNone === false) + .collect() + .run() + .toSet + ) } check(forAll(prop[Option[Int]] _)) @@ -148,7 +195,10 @@ final class FilterTests extends TypedDatasetSuite with Matchers { } test("Option content filter") { - val data = (Option(1L), Option(2L)) :: (Option(0L), Option(1L)) :: (None, None) :: Nil + val data = (Option(1L), Option(2L)) :: (Option(0L), Option(1L)) :: ( + None, + None + ) :: Nil val ds = TypedDataset.create(data) @@ -162,13 +212,20 @@ final class FilterTests extends TypedDatasetSuite with Matchers { ds.filter(exists).collect().run() shouldEqual Seq(Option(0L) -> Option(1L)) ds.filter(forall).collect().run() shouldEqual Seq( - Option(0L) -> Option(1L), (None -> None)) + Option(0L) -> Option(1L), + (None -> None) + ) } test("filter with isin values") { - def prop[A: TypedEncoder](data: Vector[X1[A]], values: Vector[A])(implicit a : CatalystIsin[A]): Prop = { + def prop[A: TypedEncoder]( + data: Vector[X1[A]], + values: Vector[A] + )(implicit + a: CatalystIsin[A] + ): Prop = { val ds = TypedDataset.create(data) - val res = ds.filter(ds('a).isin(values:_*)).collect().run().toVector + val res = ds.filter(ds('a).isin(values: _*)).collect().run().toVector res ?= data.filter(d => values.contains(d.a)) } diff --git a/dataset/src/test/scala/frameless/FlattenTests.scala b/dataset/src/test/scala/frameless/FlattenTests.scala index a65e51b8f..915eb67de 100644 --- a/dataset/src/test/scala/frameless/FlattenTests.scala +++ b/dataset/src/test/scala/frameless/FlattenTests.scala @@ -4,18 +4,19 @@ import org.scalacheck.Prop import org.scalacheck.Prop.forAll import org.scalacheck.Prop._ - class FlattenTests extends TypedDatasetSuite { test("simple flatten test") { - val ds: TypedDataset[(Int,Option[Int])] = TypedDataset.create(Seq((1,Option(1)))) - ds.flattenOption('_2): TypedDataset[(Int,Int)] + val ds: TypedDataset[(Int, Option[Int])] = + TypedDataset.create(Seq((1, Option(1)))) + ds.flattenOption('_2): TypedDataset[(Int, Int)] } test("different Optional types") { def prop[A: TypedEncoder](xs: List[X1[Option[A]]]): Prop = { val tds: TypedDataset[X1[Option[A]]] = TypedDataset.create(xs) - val framelessResults: Seq[Tuple1[A]] = tds.flattenOption('a).collect().run().toVector + val framelessResults: Seq[Tuple1[A]] = + tds.flattenOption('a).collect().run().toVector val scalaResults = xs.flatMap(_.a).map(Tuple1(_)).toVector framelessResults ?= scalaResults diff --git a/dataset/src/test/scala/frameless/GroupByTests.scala b/dataset/src/test/scala/frameless/GroupByTests.scala index 7178def30..0bc96eb20 100644 --- a/dataset/src/test/scala/frameless/GroupByTests.scala +++ b/dataset/src/test/scala/frameless/GroupByTests.scala @@ -7,20 +7,25 @@ import org.scalacheck.Prop._ class GroupByTests extends TypedDatasetSuite { test("groupByMany('a).agg(sum('b))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder, - Out: TypedEncoder : Numeric - ](data: List[X2[A, B]])( - implicit - summable: CatalystSummable[B, Out], - widen: B => Out - ): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder, + Out: 
TypedEncoder: Numeric + ](data: List[X2[A, B]] + )(implicit + summable: CatalystSummable[B, Out], + widen: B => Out + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) - val datasetSumByA = dataset.groupByMany(A).agg(sum(B)).collect().run.toVector.sortBy(_._1) - val sumByA = data.groupBy(_.a).map { case (k, v) => k -> v.map(_.b).map(widen).sum }.toVector.sortBy(_._1) + val datasetSumByA = + dataset.groupByMany(A).agg(sum(B)).collect().run.toVector.sortBy(_._1) + val sumByA = data + .groupBy(_.a) + .map { case (k, v) => k -> v.map(_.b).map(widen).sum } + .toVector + .sortBy(_._1) datasetSumByA ?= sumByA } @@ -29,10 +34,11 @@ class GroupByTests extends TypedDatasetSuite { } test("agg(sum('a))") { - def prop[A: TypedEncoder : Numeric](data: List[X1[A]])( - implicit - summable: CatalystSummable[A, A] - ): Prop = { + def prop[A: TypedEncoder: Numeric]( + data: List[X1[A]] + )(implicit + summable: CatalystSummable[A, A] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) @@ -46,14 +52,12 @@ class GroupByTests extends TypedDatasetSuite { } test("agg(sum('a), sum('b))") { - def prop[ - A: TypedEncoder : Numeric, - B: TypedEncoder : Numeric - ](data: List[X2[A, B]])( - implicit - as: CatalystSummable[A, A], - bs: CatalystSummable[B, B] - ): Prop = { + def prop[A: TypedEncoder: Numeric, B: TypedEncoder: Numeric]( + data: List[X2[A, B]] + )(implicit + as: CatalystSummable[A, A], + bs: CatalystSummable[B, B] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) @@ -70,21 +74,22 @@ class GroupByTests extends TypedDatasetSuite { test("agg(sum('a), sum('b), sum('c))") { def prop[ - A: TypedEncoder : Numeric, - B: TypedEncoder : Numeric, - C: TypedEncoder : Numeric - ](data: List[X3[A, B, C]])( - implicit - as: CatalystSummable[A, A], - bs: CatalystSummable[B, B], - cs: CatalystSummable[C, C] - ): Prop = { + A: TypedEncoder: Numeric, + B: TypedEncoder: Numeric, + C: TypedEncoder: Numeric + ](data: List[X3[A, B, C]] + )(implicit + as: CatalystSummable[A, A], + bs: CatalystSummable[B, B], + cs: CatalystSummable[C, C] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) val C = dataset.col[C]('c) - val datasetSum = dataset.agg(sum(A), sum(B), sum(C)).collect().run().toVector + val datasetSum = + dataset.agg(sum(A), sum(B), sum(C)).collect().run().toVector val listSumA = data.map(_.a).sum val listSumB = data.map(_.b).sum val listSumC = data.map(_.c).sum @@ -97,30 +102,37 @@ class GroupByTests extends TypedDatasetSuite { test("agg(sum('a), sum('b), min('c), max('d))") { def prop[ - A: TypedEncoder : Numeric, - B: TypedEncoder : Numeric, - C: TypedEncoder : Numeric, - D: TypedEncoder : Numeric - ](data: List[X4[A, B, C, D]])( - implicit - as: CatalystSummable[A, A], - bs: CatalystSummable[B, B], - co: CatalystOrdered[C], - fo: CatalystOrdered[D] - ): Prop = { + A: TypedEncoder: Numeric, + B: TypedEncoder: Numeric, + C: TypedEncoder: Numeric, + D: TypedEncoder: Numeric + ](data: List[X4[A, B, C, D]] + )(implicit + as: CatalystSummable[A, A], + bs: CatalystSummable[B, B], + co: CatalystOrdered[C], + fo: CatalystOrdered[D] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) val C = dataset.col[C]('c) val D = dataset.col[D]('d) - val datasetSum = dataset.agg(sum(A), sum(B), min(C), max(D)).collect().run().toVector + val datasetSum = + dataset.agg(sum(A), sum(B), min(C), 
max(D)).collect().run().toVector val listSumA = data.map(_.a).sum val listSumB = data.map(_.b).sum - val listMinC = if(data.isEmpty) implicitly[Numeric[C]].fromInt(0) else data.map(_.c).min - val listMaxD = if(data.isEmpty) implicitly[Numeric[D]].fromInt(0) else data.map(_.d).max - - datasetSum ?= Vector(if (data.isEmpty) null else (listSumA, listSumB, listMinC, listMaxD)) + val listMinC = + if (data.isEmpty) implicitly[Numeric[C]].fromInt(0) + else data.map(_.c).min + val listMaxD = + if (data.isEmpty) implicitly[Numeric[D]].fromInt(0) + else data.map(_.d).max + + datasetSum ?= Vector( + if (data.isEmpty) null else (listSumA, listSumB, listMinC, listMaxD) + ) } check(forAll(prop[Long, Long, Long, Int] _)) @@ -130,20 +142,25 @@ class GroupByTests extends TypedDatasetSuite { test("groupBy('a).agg(sum('b))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder, - Out: TypedEncoder : Numeric - ](data: List[X2[A, B]])( - implicit - summable: CatalystSummable[B, Out], - widen: B => Out - ): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder, + Out: TypedEncoder: Numeric + ](data: List[X2[A, B]] + )(implicit + summable: CatalystSummable[B, Out], + widen: B => Out + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) - val datasetSumByA = dataset.groupBy(A).agg(sum(B)).collect().run.toVector.sortBy(_._1) - val sumByA = data.groupBy(_.a).mapValues(_.map(_.b).map(widen).sum).toVector.sortBy(_._1) + val datasetSumByA = + dataset.groupBy(A).agg(sum(B)).collect().run.toVector.sortBy(_._1) + val sumByA = data + .groupBy(_.a) + .mapValues(_.map(_.b).map(widen).sum) + .toVector + .sortBy(_._1) datasetSumByA ?= sumByA } @@ -152,18 +169,23 @@ class GroupByTests extends TypedDatasetSuite { } test("groupBy('a).mapGroups('a, sum('b))") { - def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Numeric - ](data: List[X2[A, B]]): Prop = { + def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Numeric]( + data: List[X2[A, B]] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) - val datasetSumByA = dataset.groupBy(A) - .deserialized.mapGroups { case (a, xs) => (a, xs.map(_.b).sum) } - .collect().run().toVector.sortBy(_._1) - val sumByA = data.groupBy(_.a).mapValues(_.map(_.b).sum).toVector.sortBy(_._1) + val datasetSumByA = dataset + .groupBy(A) + .deserialized + .mapGroups { case (a, xs) => (a, xs.map(_.b).sum) } + .collect() + .run() + .toVector + .sortBy(_._1) + val sumByA = + data.groupBy(_.a).mapValues(_.map(_.b).sum).toVector.sortBy(_._1) datasetSumByA ?= sumByA } @@ -173,18 +195,18 @@ class GroupByTests extends TypedDatasetSuite { test("groupBy('a).agg(sum('b), sum('c)) to groupBy('a).agg(sum('a), sum('b), sum('a), sum('b), sum('a))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder, - C: TypedEncoder, - OutB: TypedEncoder : Numeric, - OutC: TypedEncoder : Numeric - ](data: List[X3[A, B, C]])( - implicit - summableB: CatalystSummable[B, OutB], - summableC: CatalystSummable[C, OutC], - widenb: B => OutB, - widenc: C => OutC - ): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder, + C: TypedEncoder, + OutB: TypedEncoder: Numeric, + OutC: TypedEncoder: Numeric + ](data: List[X3[A, B, C]] + )(implicit + summableB: CatalystSummable[B, OutB], + summableC: CatalystSummable[C, OutC], + widenb: B => OutB, + widenc: C => OutC + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) @@ -193,46 +215,85 @@ class GroupByTests 
extends TypedDatasetSuite { val framelessSumBC = dataset .groupBy(A) .agg(sum(B), sum(C)) - .collect().run.toVector.sortBy(_._1) - - val scalaSumBC = data.groupBy(_.a).mapValues { xs => - (xs.map(_.b).map(widenb).sum, xs.map(_.c).map(widenc).sum) - }.toVector.map { - case (a, (b, c)) => (a, b, c) - }.sortBy(_._1) + .collect() + .run + .toVector + .sortBy(_._1) + + val scalaSumBC = data + .groupBy(_.a) + .mapValues { xs => + (xs.map(_.b).map(widenb).sum, xs.map(_.c).map(widenc).sum) + } + .toVector + .map { case (a, (b, c)) => (a, b, c) } + .sortBy(_._1) val framelessSumBCB = dataset .groupBy(A) .agg(sum(B), sum(C), sum(B)) - .collect().run.toVector.sortBy(_._1) - - val scalaSumBCB = data.groupBy(_.a).mapValues { xs => - (xs.map(_.b).map(widenb).sum, xs.map(_.c).map(widenc).sum, xs.map(_.b).map(widenb).sum) - }.toVector.map { - case (a, (b1, c, b2)) => (a, b1, c, b2) - }.sortBy(_._1) + .collect() + .run + .toVector + .sortBy(_._1) + + val scalaSumBCB = data + .groupBy(_.a) + .mapValues { xs => + ( + xs.map(_.b).map(widenb).sum, + xs.map(_.c).map(widenc).sum, + xs.map(_.b).map(widenb).sum + ) + } + .toVector + .map { case (a, (b1, c, b2)) => (a, b1, c, b2) } + .sortBy(_._1) val framelessSumBCBC = dataset .groupBy(A) .agg(sum(B), sum(C), sum(B), sum(C)) - .collect().run.toVector.sortBy(_._1) - - val scalaSumBCBC = data.groupBy(_.a).mapValues { xs => - (xs.map(_.b).map(widenb).sum, xs.map(_.c).map(widenc).sum, xs.map(_.b).map(widenb).sum, xs.map(_.c).map(widenc).sum) - }.toVector.map { - case (a, (b1, c1, b2, c2)) => (a, b1, c1, b2, c2) - }.sortBy(_._1) + .collect() + .run + .toVector + .sortBy(_._1) + + val scalaSumBCBC = data + .groupBy(_.a) + .mapValues { xs => + ( + xs.map(_.b).map(widenb).sum, + xs.map(_.c).map(widenc).sum, + xs.map(_.b).map(widenb).sum, + xs.map(_.c).map(widenc).sum + ) + } + .toVector + .map { case (a, (b1, c1, b2, c2)) => (a, b1, c1, b2, c2) } + .sortBy(_._1) val framelessSumBCBCB = dataset .groupBy(A) .agg(sum(B), sum(C), sum(B), sum(C), sum(B)) - .collect().run.toVector.sortBy(_._1) - - val scalaSumBCBCB = data.groupBy(_.a).mapValues { xs => - (xs.map(_.b).map(widenb).sum, xs.map(_.c).map(widenc).sum, xs.map(_.b).map(widenb).sum, xs.map(_.c).map(widenc).sum, xs.map(_.b).map(widenb).sum) - }.toVector.map { - case (a, (b1, c1, b2, c2, b3)) => (a, b1, c1, b2, c2, b3) - }.sortBy(_._1) + .collect() + .run + .toVector + .sortBy(_._1) + + val scalaSumBCBCB = data + .groupBy(_.a) + .mapValues { xs => + ( + xs.map(_.b).map(widenb).sum, + xs.map(_.c).map(widenc).sum, + xs.map(_.b).map(widenb).sum, + xs.map(_.c).map(widenc).sum, + xs.map(_.b).map(widenb).sum + ) + } + .toVector + .map { case (a, (b1, c1, b2, c2, b3)) => (a, b1, c1, b2, c2, b3) } + .sortBy(_._1) (framelessSumBC ?= scalaSumBC) .&&(framelessSumBCB ?= scalaSumBCB) @@ -245,70 +306,110 @@ class GroupByTests extends TypedDatasetSuite { test("groupBy('a, 'b).agg(sum('c)) to groupBy('a, 'b).agg(sum('c),sum('c),sum('c),sum('c),sum('c))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering, - C: TypedEncoder, - OutC: TypedEncoder: Numeric - ](data: List[X3[A, B, C]])( - implicit - summableC: CatalystSummable[C, OutC], - widenc: C => OutC - ): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder, + OutC: TypedEncoder: Numeric + ](data: List[X3[A, B, C]] + )(implicit + summableC: CatalystSummable[C, OutC], + widenc: C => OutC + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) val C = dataset.col[C]('c) val 
framelessSumC = dataset - .groupBy(A,B) + .groupBy(A, B) .agg(sum(C)) - .collect().run.toVector.sortBy(x => (x._1,x._2)) - - val scalaSumC = data.groupBy(x => (x.a,x.b)).mapValues { xs => - xs.map(_.c).map(widenc).sum - }.toVector.map { case ((a, b), c) => (a, b, c) }.sortBy(x => (x._1,x._2)) + .collect() + .run + .toVector + .sortBy(x => (x._1, x._2)) + + val scalaSumC = data + .groupBy(x => (x.a, x.b)) + .mapValues { xs => xs.map(_.c).map(widenc).sum } + .toVector + .map { case ((a, b), c) => (a, b, c) } + .sortBy(x => (x._1, x._2)) val framelessSumCC = dataset - .groupBy(A,B) + .groupBy(A, B) .agg(sum(C), sum(C)) - .collect().run.toVector.sortBy(x => (x._1,x._2)) - - val scalaSumCC = data.groupBy(x => (x.a,x.b)).mapValues { xs => - val s = xs.map(_.c).map(widenc).sum; (s,s) - }.toVector.map { case ((a, b), (c1, c2)) => (a, b, c1, c2) }.sortBy(x => (x._1,x._2)) + .collect() + .run + .toVector + .sortBy(x => (x._1, x._2)) + + val scalaSumCC = data + .groupBy(x => (x.a, x.b)) + .mapValues { xs => + val s = xs.map(_.c).map(widenc).sum; (s, s) + } + .toVector + .map { case ((a, b), (c1, c2)) => (a, b, c1, c2) } + .sortBy(x => (x._1, x._2)) val framelessSumCCC = dataset - .groupBy(A,B) + .groupBy(A, B) .agg(sum(C), sum(C), sum(C)) - .collect().run.toVector.sortBy(x => (x._1,x._2)) - - val scalaSumCCC = data.groupBy(x => (x.a,x.b)).mapValues { xs => - val s = xs.map(_.c).map(widenc).sum; (s,s,s) - }.toVector.map { case ((a, b), (c1, c2, c3)) => (a, b, c1, c2, c3) }.sortBy(x => (x._1,x._2)) + .collect() + .run + .toVector + .sortBy(x => (x._1, x._2)) + + val scalaSumCCC = data + .groupBy(x => (x.a, x.b)) + .mapValues { xs => + val s = xs.map(_.c).map(widenc).sum; (s, s, s) + } + .toVector + .map { case ((a, b), (c1, c2, c3)) => (a, b, c1, c2, c3) } + .sortBy(x => (x._1, x._2)) val framelessSumCCCC = dataset - .groupBy(A,B) + .groupBy(A, B) .agg(sum(C), sum(C), sum(C), sum(C)) - .collect().run.toVector.sortBy(x => (x._1,x._2)) - - val scalaSumCCCC = data.groupBy(x => (x.a,x.b)).mapValues { xs => - val s = xs.map(_.c).map(widenc).sum; (s,s,s,s) - }.toVector.map { case ((a, b), (c1, c2, c3, c4)) => (a, b, c1, c2, c3, c4) }.sortBy(x => (x._1,x._2)) + .collect() + .run + .toVector + .sortBy(x => (x._1, x._2)) + + val scalaSumCCCC = data + .groupBy(x => (x.a, x.b)) + .mapValues { xs => + val s = xs.map(_.c).map(widenc).sum; (s, s, s, s) + } + .toVector + .map { case ((a, b), (c1, c2, c3, c4)) => (a, b, c1, c2, c3, c4) } + .sortBy(x => (x._1, x._2)) val framelessSumCCCCC = dataset - .groupBy(A,B) + .groupBy(A, B) .agg(sum(C), sum(C), sum(C), sum(C), sum(C)) - .collect().run.toVector.sortBy(x => (x._1,x._2)) - - val scalaSumCCCCC = data.groupBy(x => (x.a,x.b)).mapValues { xs => - val s = xs.map(_.c).map(widenc).sum; (s,s,s,s,s) - }.toVector.map { case ((a, b), (c1, c2, c3, c4, c5)) => (a, b, c1, c2, c3, c4, c5) }.sortBy(x => (x._1,x._2)) + .collect() + .run + .toVector + .sortBy(x => (x._1, x._2)) + + val scalaSumCCCCC = data + .groupBy(x => (x.a, x.b)) + .mapValues { xs => + val s = xs.map(_.c).map(widenc).sum; (s, s, s, s, s) + } + .toVector + .map { + case ((a, b), (c1, c2, c3, c4, c5)) => (a, b, c1, c2, c3, c4, c5) + } + .sortBy(x => (x._1, x._2)) (framelessSumC ?= scalaSumC) && - (framelessSumCC ?= scalaSumCC) && - (framelessSumCCC ?= scalaSumCCC) && - (framelessSumCCCC ?= scalaSumCCCC) && - (framelessSumCCCCC ?= scalaSumCCCCC) + (framelessSumCC ?= scalaSumCC) && + (framelessSumCCC ?= scalaSumCCC) && + (framelessSumCCCC ?= scalaSumCCCC) && + (framelessSumCCCCC ?= scalaSumCCCCC) } 
check(forAll(prop[String, Long, BigDecimal, BigDecimal] _)) @@ -316,19 +417,19 @@ class GroupByTests extends TypedDatasetSuite { test("groupBy('a, 'b).agg(sum('c), sum('d))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering, - C: TypedEncoder, - D: TypedEncoder, - OutC: TypedEncoder : Numeric, - OutD: TypedEncoder : Numeric - ](data: List[X4[A, B, C, D]])( - implicit - summableC: CatalystSummable[C, OutC], - summableD: CatalystSummable[D, OutD], - widenc: C => OutC, - widend: D => OutD - ): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder, + D: TypedEncoder, + OutC: TypedEncoder: Numeric, + OutD: TypedEncoder: Numeric + ](data: List[X4[A, B, C, D]] + )(implicit + summableC: CatalystSummable[C, OutC], + summableD: CatalystSummable[D, OutD], + widenc: C => OutC, + widend: D => OutD + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) @@ -338,13 +439,19 @@ class GroupByTests extends TypedDatasetSuite { val datasetSumByAB = dataset .groupBy(A, B) .agg(sum(C), sum(D)) - .collect().run.toVector.sortBy(x => (x._1, x._2)) - - val sumByAB = data.groupBy(x => (x.a, x.b)).mapValues { xs => - (xs.map(_.c).map(widenc).sum, xs.map(_.d).map(widend).sum) - }.toVector.map { - case ((a, b), (c, d)) => (a, b, c, d) - }.sortBy(x => (x._1, x._2)) + .collect() + .run + .toVector + .sortBy(x => (x._1, x._2)) + + val sumByAB = data + .groupBy(x => (x.a, x.b)) + .mapValues { xs => + (xs.map(_.c).map(widenc).sum, xs.map(_.d).map(widend).sum) + } + .toVector + .map { case ((a, b), (c, d)) => (a, b, c, d) } + .sortBy(x => (x._1, x._2)) datasetSumByAB ?= sumByAB } @@ -354,10 +461,11 @@ class GroupByTests extends TypedDatasetSuite { test("groupBy('a, 'b).mapGroups('a, 'b, sum('c))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering, - C: TypedEncoder : Numeric - ](data: List[X3[A, B, C]]): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder: Numeric + ](data: List[X3[A, B, C]] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) @@ -365,12 +473,19 @@ class GroupByTests extends TypedDatasetSuite { val datasetSumByAB = dataset .groupBy(A, B) - .deserialized.mapGroups { case ((a, b), xs) => (a, b, xs.map(_.c).sum) } - .collect().run().toVector.sortBy(x => (x._1, x._2)) - - val sumByAB = data.groupBy(x => (x.a, x.b)) + .deserialized + .mapGroups { case ((a, b), xs) => (a, b, xs.map(_.c).sum) } + .collect() + .run() + .toVector + .sortBy(x => (x._1, x._2)) + + val sumByAB = data + .groupBy(x => (x.a, x.b)) .mapValues { xs => xs.map(_.c).sum } - .toVector.map { case ((a, b), c) => (a, b, c) }.sortBy(x => (x._1, x._2)) + .toVector + .map { case ((a, b), c) => (a, b, c) } + .sortBy(x => (x._1, x._2)) datasetSumByAB ?= sumByAB } @@ -379,17 +494,19 @@ class GroupByTests extends TypedDatasetSuite { } test("groupBy('a).mapGroups(('a, toVector(('a, 'b))") { - def prop[ - A: TypedEncoder: Ordering, - B: TypedEncoder: Ordering - ](data: Vector[X2[A, B]]): Prop = { + def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Ordering]( + data: Vector[X2[A, B]] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val datasetGrouped = dataset .groupBy(A) - .deserialized.mapGroups((a, xs) => (a, xs.toVector.sorted)) - .collect().run.toMap + .deserialized + .mapGroups((a, xs) => (a, xs.toVector.sorted)) + .collect() + .run + .toMap val dataGrouped = data.groupBy(_.a).map { case (k, v) => k -> 
v.sorted } @@ -402,21 +519,23 @@ class GroupByTests extends TypedDatasetSuite { } test("groupBy('a).flatMapGroups(('a, toVector(('a, 'b))") { - def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering - ](data: Vector[X2[A, B]]): Prop = { + def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Ordering]( + data: Vector[X2[A, B]] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val datasetGrouped = dataset .groupBy(A) - .deserialized.flatMapGroups((a, xs) => xs.map(x => (a, x))) - .collect().run + .deserialized + .flatMapGroups((a, xs) => xs.map(x => (a, x))) + .collect() + .run .sorted val dataGrouped = data - .groupBy(_.a).toSeq + .groupBy(_.a) + .toSeq .flatMap { case (a, xs) => xs.map(x => (a, x)) } .sorted @@ -430,22 +549,26 @@ class GroupByTests extends TypedDatasetSuite { test("groupBy('a, 'b).flatMapGroups((('a,'b) toVector((('a,'b), 'c))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering, - C: TypedEncoder : Ordering - ](data: Vector[X3[A, B, C]]): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder: Ordering + ](data: Vector[X3[A, B, C]] + ): Prop = { val dataset = TypedDataset.create(data) val cA = dataset.col[A]('a) val cB = dataset.col[B]('b) val datasetGrouped = dataset .groupBy(cA, cB) - .deserialized.flatMapGroups((a, xs) => xs.map(x => (a, x))) - .collect().run() + .deserialized + .flatMapGroups((a, xs) => xs.map(x => (a, x))) + .collect() + .run() .sorted val dataGrouped = data - .groupBy(t => (t.a,t.b)).toSeq + .groupBy(t => (t.a, t.b)) + .toSeq .flatMap { case (a, xs) => xs.map(x => (a, x)) } .sorted diff --git a/dataset/src/test/scala/frameless/InjectionTests.scala b/dataset/src/test/scala/frameless/InjectionTests.scala index c17a52bd7..9ee252616 100644 --- a/dataset/src/test/scala/frameless/InjectionTests.scala +++ b/dataset/src/test/scala/frameless/InjectionTests.scala @@ -10,6 +10,7 @@ case object France extends Country case object Russia extends Country object Country { + implicit val arbitrary: Arbitrary[Country] = Arbitrary(Arbitrary.arbitrary[Boolean].map(injection.invert)) @@ -23,15 +24,18 @@ case object Pasta extends Food case object Rice extends Food object Food { + implicit val arbitrary: Arbitrary[Food] = - Arbitrary(Arbitrary.arbitrary[Int].map(i => injection.invert(Math.abs(i % 3)))) + Arbitrary( + Arbitrary.arbitrary[Int].map(i => injection.invert(Math.abs(i % 3))) + ) implicit val injection: Injection[Food, Int] = Injection( { case Burger => 0 - case Pasta => 1 - case Rice => 2 + case Pasta => 1 + case Rice => 2 }, { case 0 => Burger @@ -46,10 +50,13 @@ class LocalDateTime { var instant: Long = _ override def equals(o: Any): Boolean = - o.isInstanceOf[LocalDateTime] && o.asInstanceOf[LocalDateTime].instant == instant + o.isInstanceOf[LocalDateTime] && o + .asInstanceOf[LocalDateTime] + .instant == instant } object LocalDateTime { + implicit val arbitrary: Arbitrary[LocalDateTime] = Arbitrary(Arbitrary.arbitrary[Long].map(injection.invert)) @@ -76,8 +83,12 @@ case class I[A](value: A) object I { implicit def injection[A]: Injection[I[A], A] = Injection(_.value, I(_)) - implicit def typedEncoder[A: TypedEncoder]: TypedEncoder[I[A]] = TypedEncoder.usingInjection[I[A], A] - implicit def arbitrary[A: Arbitrary]: Arbitrary[I[A]] = Arbitrary(Arbitrary.arbitrary[A].map(I(_))) + + implicit def typedEncoder[A: TypedEncoder]: TypedEncoder[I[A]] = + TypedEncoder.usingInjection[I[A], A] + + implicit def arbitrary[A: Arbitrary]: Arbitrary[I[A]] = + 
Arbitrary(Arbitrary.arbitrary[A].map(I(_))) } sealed trait Employee @@ -86,6 +97,7 @@ case object PartTime extends Employee case object FullTime extends Employee object Employee { + implicit val arbitrary: Arbitrary[Employee] = Arbitrary(Gen.oneOf(Casual, PartTime, FullTime)) } @@ -95,6 +107,7 @@ case object Nothing extends Maybe case class Just(get: Int) extends Maybe sealed trait Switch + object Switch { case object Off extends Switch case object On extends Switch @@ -109,6 +122,7 @@ case class Green() extends Pixel case class Blue() extends Pixel object Pixel { + implicit val arbitrary: Arbitrary[Pixel] = Arbitrary(Gen.oneOf(Red(), Green(), Blue())) } @@ -118,6 +132,7 @@ case object Closed extends Connection[Nothing] case object Open extends Connection[Nothing] object Connection { + implicit def arbitrary[A]: Arbitrary[Connection[A]] = Arbitrary(Gen.oneOf(Closed, Open)) } @@ -127,6 +142,7 @@ case object Car extends Vehicle("red") case object Bike extends Vehicle("blue") object Vehicle { + implicit val arbitrary: Arbitrary[Vehicle] = Arbitrary(Gen.oneOf(Car, Bike)) } @@ -159,7 +175,9 @@ class InjectionTests extends TypedDatasetSuite { check(forAll(prop[Option[I[X1[Int]]]] _)) assert(TypedEncoder[I[Int]].catalystRepr == TypedEncoder[Int].catalystRepr) - assert(TypedEncoder[I[I[Int]]].catalystRepr == TypedEncoder[Int].catalystRepr) + assert( + TypedEncoder[I[I[Int]]].catalystRepr == TypedEncoder[Int].catalystRepr + ) assert(TypedEncoder[I[Option[Int]]].nullable) } @@ -176,12 +194,18 @@ class InjectionTests extends TypedDatasetSuite { check(forAll(prop[X2[Person, Person]] _)) check(forAll(prop[Person] _)) - assert(TypedEncoder[Person].catalystRepr == TypedEncoder[(Int, String)].catalystRepr) + assert( + TypedEncoder[Person].catalystRepr == TypedEncoder[ + (Int, String) + ].catalystRepr + ) } test("Resolve ambiguity by importing usingDerivation") { import TypedEncoder.usingDerivation - assert(implicitly[TypedEncoder[Person]].isInstanceOf[RecordEncoder[Person, _, _]]) + assert( + implicitly[TypedEncoder[Person]].isInstanceOf[RecordEncoder[Person, _, _]] + ) check(forAll(prop[Person] _)) } @@ -200,7 +224,9 @@ class InjectionTests extends TypedDatasetSuite { check(forAll(prop[X2[Employee, Employee]] _)) check(forAll(prop[Employee] _)) - assert(TypedEncoder[Employee].catalystRepr == TypedEncoder[String].catalystRepr) + assert( + TypedEncoder[Employee].catalystRepr == TypedEncoder[String].catalystRepr + ) } test("TypedEncoder[Maybe] cannot be derived") { @@ -220,7 +246,9 @@ class InjectionTests extends TypedDatasetSuite { check(forAll(prop[X2[Switch, Switch]] _)) check(forAll(prop[Switch] _)) - assert(TypedEncoder[Switch].catalystRepr == TypedEncoder[String].catalystRepr) + assert( + TypedEncoder[Switch].catalystRepr == TypedEncoder[String].catalystRepr + ) } test("Derive encoder for type with data constructors defined as parameterless case classes") { @@ -231,7 +259,9 @@ class InjectionTests extends TypedDatasetSuite { check(forAll(prop[X2[Pixel, Pixel]] _)) check(forAll(prop[Pixel] _)) - assert(TypedEncoder[Pixel].catalystRepr == TypedEncoder[String].catalystRepr) + assert( + TypedEncoder[Pixel].catalystRepr == TypedEncoder[String].catalystRepr + ) } test("Derive encoder for phantom type") { @@ -242,7 +272,11 @@ class InjectionTests extends TypedDatasetSuite { check(forAll(prop[X2[Connection[Int], Connection[Int]]] _)) check(forAll(prop[Connection[Int]] _)) - assert(TypedEncoder[Connection[Int]].catalystRepr == TypedEncoder[String].catalystRepr) + assert( + 
TypedEncoder[Connection[Int]].catalystRepr == TypedEncoder[ + String + ].catalystRepr + ) } test("Derive encoder for ADT with abstract class as the base type") { @@ -253,26 +287,36 @@ class InjectionTests extends TypedDatasetSuite { check(forAll(prop[X2[Vehicle, Vehicle]] _)) check(forAll(prop[Vehicle] _)) - assert(TypedEncoder[Vehicle].catalystRepr == TypedEncoder[String].catalystRepr) + assert( + TypedEncoder[Vehicle].catalystRepr == TypedEncoder[String].catalystRepr + ) } - test("apply method of derived Injection instance produces the correct string") { + test( + "apply method of derived Injection instance produces the correct string" + ) { import frameless.TypedEncoder.injections._ assert(implicitly[Injection[Employee, String]].apply(Casual) === "Casual") assert(implicitly[Injection[Switch, String]].apply(Switch.On) === "On") assert(implicitly[Injection[Pixel, String]].apply(Blue()) === "Blue") - assert(implicitly[Injection[Connection[Int], String]].apply(Open) === "Open") + assert( + implicitly[Injection[Connection[Int], String]].apply(Open) === "Open" + ) assert(implicitly[Injection[Vehicle, String]].apply(Bike) === "Bike") } - test("invert method of derived Injection instance produces the correct value") { + test( + "invert method of derived Injection instance produces the correct value" + ) { import frameless.TypedEncoder.injections._ assert(implicitly[Injection[Employee, String]].invert("Casual") === Casual) assert(implicitly[Injection[Switch, String]].invert("On") === Switch.On) assert(implicitly[Injection[Pixel, String]].invert("Blue") === Blue()) - assert(implicitly[Injection[Connection[Int], String]].invert("Open") === Open) + assert( + implicitly[Injection[Connection[Int], String]].invert("Open") === Open + ) assert(implicitly[Injection[Vehicle, String]].invert("Bike") === Bike) } diff --git a/dataset/src/test/scala/frameless/JobTests.scala b/dataset/src/test/scala/frameless/JobTests.scala index 9650a020f..8ef20970c 100644 --- a/dataset/src/test/scala/frameless/JobTests.scala +++ b/dataset/src/test/scala/frameless/JobTests.scala @@ -6,13 +6,20 @@ import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks import org.scalatest.freespec.AnyFreeSpec import org.scalatest.matchers.should.Matchers - -class JobTests extends AnyFreeSpec with BeforeAndAfterAll with SparkTesting with ScalaCheckDrivenPropertyChecks with Matchers { +class JobTests + extends AnyFreeSpec + with BeforeAndAfterAll + with SparkTesting + with ScalaCheckDrivenPropertyChecks + with Matchers { "map" - { "identity" in { - def check[T](implicit arb: Arbitrary[T]) = forAll { - t: T => Job(t).map(identity).run() shouldEqual Job(t).run() + def check[T]( + implicit + arb: Arbitrary[T] + ) = forAll { t: T => + Job(t).map(identity).run() shouldEqual Job(t).run() } check[Int] @@ -21,8 +28,8 @@ class JobTests extends AnyFreeSpec with BeforeAndAfterAll with SparkTesting with val f1: Int => Int = _ + 1 val f2: Int => Int = (i: Int) => i * i - "composition" in forAll { - i: Int => Job(i).map(f1).map(f2).run() shouldEqual Job(i).map(f1 andThen f2).run() + "composition" in forAll { i: Int => + Job(i).map(f1).map(f2).run() shouldEqual Job(i).map(f1 andThen f2).run() } } @@ -30,25 +37,26 @@ class JobTests extends AnyFreeSpec with BeforeAndAfterAll with SparkTesting with val f1: Int => Job[Int] = (i: Int) => Job(i + 1) val f2: Int => Job[Int] = (i: Int) => Job(i * i) - "left identity" in forAll { - i: Int => Job(i).flatMap(f1).run() shouldEqual f1(i).run() + "left identity" in forAll { i: Int => + Job(i).flatMap(f1).run() 
shouldEqual f1(i).run() } - "right identity" in forAll { - i: Int => Job(i).flatMap(i => Job.apply(i)).run() shouldEqual Job(i).run() + "right identity" in forAll { i: Int => + Job(i).flatMap(i => Job.apply(i)).run() shouldEqual Job(i).run() } - "associativity" in forAll { - i: Int => Job(i).flatMap(f1).flatMap(f2).run() shouldEqual Job(i).flatMap(ii => f1(ii).flatMap(f2)).run() + "associativity" in forAll { i: Int => + Job(i).flatMap(f1).flatMap(f2).run() shouldEqual Job(i) + .flatMap(ii => f1(ii).flatMap(f2)) + .run() } } "properties" - { - "read back" in forAll { - (k:String, v: String) => - val scopedKey = "frameless.tests." + k - Job(1).withLocalProperty(scopedKey,v).run() - sc.getLocalProperty(scopedKey) shouldBe v + "read back" in forAll { (k: String, v: String) => + val scopedKey = "frameless.tests." + k + Job(1).withLocalProperty(scopedKey, v).run() + sc.getLocalProperty(scopedKey) shouldBe v } } -} \ No newline at end of file +} diff --git a/dataset/src/test/scala/frameless/JoinTests.scala b/dataset/src/test/scala/frameless/JoinTests.scala index b34911c4f..a07ce4b35 100644 --- a/dataset/src/test/scala/frameless/JoinTests.scala +++ b/dataset/src/test/scala/frameless/JoinTests.scala @@ -1,20 +1,21 @@ package frameless -import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.types.{ StructField, StructType } import org.scalacheck.Prop import org.scalacheck.Prop._ class JoinTests extends TypedDatasetSuite { test("ab.joinCross(ac)") { def prop[ - A : TypedEncoder : Ordering, - B : TypedEncoder : Ordering, - C : TypedEncoder : Ordering - ](left: List[X2[A, B]], right: List[X2[A, C]]): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder: Ordering + ](left: List[X2[A, B]], + right: List[X2[A, C]] + ): Prop = { val leftDs = TypedDataset.create(left) val rightDs = TypedDataset.create(right) - val joinedDs = leftDs - .joinCross(rightDs) + val joinedDs = leftDs.joinCross(rightDs) val joinedData = joinedDs.collect().run().toVector.sorted @@ -25,9 +26,12 @@ class JoinTests extends TypedDatasetSuite { } yield (ab, ac) }.toVector - val equalSchemas = joinedDs.schema ?= StructType(Seq( - StructField("_1", leftDs.schema, nullable = false), - StructField("_2", rightDs.schema, nullable = false))) + val equalSchemas = joinedDs.schema ?= StructType( + Seq( + StructField("_1", leftDs.schema, nullable = false), + StructField("_2", rightDs.schema, nullable = false) + ) + ) (joined.sorted ?= joinedData) && equalSchemas } @@ -37,19 +41,21 @@ class JoinTests extends TypedDatasetSuite { test("ab.joinFull(ac)(ab.a == ac.a)") { def prop[ - A : TypedEncoder : Ordering, - B : TypedEncoder : Ordering, - C : TypedEncoder : Ordering - ](left: List[X2[A, B]], right: List[X2[A, C]]): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder: Ordering + ](left: List[X2[A, B]], + right: List[X2[A, C]] + ): Prop = { val leftDs = TypedDataset.create(left) val rightDs = TypedDataset.create(right) - val joinedDs = leftDs - .joinFull(rightDs)(leftDs.col('a) === rightDs.col('a)) + val joinedDs = + leftDs.joinFull(rightDs)(leftDs.col('a) === rightDs.col('a)) val joinedData = joinedDs.collect().run().toVector.sorted val rightKeys = right.map(_.a).toSet - val leftKeys = left.map(_.a).toSet + val leftKeys = left.map(_.a).toSet val joined = { for { ab <- left @@ -65,9 +71,12 @@ class JoinTests extends TypedDatasetSuite { } yield (None, Some(ac)) }.toVector - val equalSchemas = joinedDs.schema ?= StructType(Seq( - StructField("_1", 
leftDs.schema, nullable = true), - StructField("_2", rightDs.schema, nullable = true))) + val equalSchemas = joinedDs.schema ?= StructType( + Seq( + StructField("_1", leftDs.schema, nullable = true), + StructField("_2", rightDs.schema, nullable = true) + ) + ) (joined.sorted ?= joinedData) && equalSchemas } @@ -77,14 +86,16 @@ class JoinTests extends TypedDatasetSuite { test("ab.joinInner(ac)(ab.a == ac.a)") { def prop[ - A : TypedEncoder : Ordering, - B : TypedEncoder : Ordering, - C : TypedEncoder : Ordering - ](left: List[X2[A, B]], right: List[X2[A, C]]): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder: Ordering + ](left: List[X2[A, B]], + right: List[X2[A, C]] + ): Prop = { val leftDs = TypedDataset.create(left) val rightDs = TypedDataset.create(right) - val joinedDs = leftDs - .joinInner(rightDs)(leftDs.col('a) === rightDs.col('a)) + val joinedDs = + leftDs.joinInner(rightDs)(leftDs.col('a) === rightDs.col('a)) val joinedData = joinedDs.collect().run().toVector.sorted @@ -95,9 +106,12 @@ class JoinTests extends TypedDatasetSuite { } yield (ab, ac) }.toVector - val equalSchemas = joinedDs.schema ?= StructType(Seq( - StructField("_1", leftDs.schema, nullable = false), - StructField("_2", rightDs.schema, nullable = false))) + val equalSchemas = joinedDs.schema ?= StructType( + Seq( + StructField("_1", leftDs.schema, nullable = false), + StructField("_2", rightDs.schema, nullable = false) + ) + ) (joined.sorted ?= joinedData) && equalSchemas } @@ -107,14 +121,16 @@ class JoinTests extends TypedDatasetSuite { test("ab.joinLeft(ac)(ab.a == ac.a)") { def prop[ - A : TypedEncoder : Ordering, - B : TypedEncoder : Ordering, - C : TypedEncoder : Ordering - ](left: List[X2[A, B]], right: List[X2[A, C]]): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder: Ordering + ](left: List[X2[A, B]], + right: List[X2[A, C]] + ): Prop = { val leftDs = TypedDataset.create(left) val rightDs = TypedDataset.create(right) - val joinedDs = leftDs - .joinLeft(rightDs)(leftDs.col('a) === rightDs.col('a)) + val joinedDs = + leftDs.joinLeft(rightDs)(leftDs.col('a) === rightDs.col('a)) val joinedData = joinedDs.collect().run().toVector.sorted @@ -130,11 +146,16 @@ class JoinTests extends TypedDatasetSuite { } yield (ab, None) }.toVector - val equalSchemas = joinedDs.schema ?= StructType(Seq( - StructField("_1", leftDs.schema, nullable = false), - StructField("_2", rightDs.schema, nullable = true))) + val equalSchemas = joinedDs.schema ?= StructType( + Seq( + StructField("_1", leftDs.schema, nullable = false), + StructField("_2", rightDs.schema, nullable = true) + ) + ) - (joined.sorted ?= joinedData) && (joinedData.map(_._1).toSet ?= left.toSet) && equalSchemas + (joined.sorted ?= joinedData) && (joinedData + .map(_._1) + .toSet ?= left.toSet) && equalSchemas } check(forAll(prop[Int, Long, String] _)) @@ -142,15 +163,17 @@ class JoinTests extends TypedDatasetSuite { test("ab.joinLeftAnti(ac)(ab.a == ac.a)") { def prop[ - A : TypedEncoder : Ordering, - B : TypedEncoder : Ordering, - C : TypedEncoder : Ordering - ](left: List[X2[A, B]], right: List[X2[A, C]]): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder: Ordering + ](left: List[X2[A, B]], + right: List[X2[A, C]] + ): Prop = { val leftDs = TypedDataset.create(left) val rightDs = TypedDataset.create(right) val rightKeys = right.map(_.a).toSet - val joinedDs = leftDs - .joinLeftAnti(rightDs)(leftDs.col('a) === rightDs.col('a)) + val joinedDs = + 
leftDs.joinLeftAnti(rightDs)(leftDs.col('a) === rightDs.col('a)) val joinedData = joinedDs.collect().run().toVector.sorted @@ -170,15 +193,17 @@ class JoinTests extends TypedDatasetSuite { test("ab.joinLeftSemi(ac)(ab.a == ac.a)") { def prop[ - A : TypedEncoder : Ordering, - B : TypedEncoder : Ordering, - C : TypedEncoder : Ordering - ](left: List[X2[A, B]], right: List[X2[A, C]]): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder: Ordering + ](left: List[X2[A, B]], + right: List[X2[A, C]] + ): Prop = { val leftDs = TypedDataset.create(left) val rightDs = TypedDataset.create(right) val rightKeys = right.map(_.a).toSet - val joinedDs = leftDs - .joinLeftSemi(rightDs)(leftDs.col('a) === rightDs.col('a)) + val joinedDs = + leftDs.joinLeftSemi(rightDs)(leftDs.col('a) === rightDs.col('a)) val joinedData = joinedDs.collect().run().toVector.sorted @@ -198,14 +223,16 @@ class JoinTests extends TypedDatasetSuite { test("ab.joinRight(ac)(ab.a == ac.a)") { def prop[ - A : TypedEncoder : Ordering, - B : TypedEncoder : Ordering, - C : TypedEncoder : Ordering - ](left: List[X2[A, B]], right: List[X2[A, C]]): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder: Ordering + ](left: List[X2[A, B]], + right: List[X2[A, C]] + ): Prop = { val leftDs = TypedDataset.create(left) val rightDs = TypedDataset.create(right) - val joinedDs = leftDs - .joinRight(rightDs)(leftDs.col('a) === rightDs.col('a)) + val joinedDs = + leftDs.joinRight(rightDs)(leftDs.col('a) === rightDs.col('a)) val joinedData = joinedDs.collect().run().toVector.sorted @@ -221,11 +248,16 @@ class JoinTests extends TypedDatasetSuite { } yield (None, ac) }.toVector - val equalSchemas = joinedDs.schema ?= StructType(Seq( - StructField("_1", leftDs.schema, nullable = true), - StructField("_2", rightDs.schema, nullable = false))) + val equalSchemas = joinedDs.schema ?= StructType( + Seq( + StructField("_1", leftDs.schema, nullable = true), + StructField("_2", rightDs.schema, nullable = false) + ) + ) - (joined.sorted ?= joinedData) && (joinedData.map(_._2).toSet ?= right.toSet) && equalSchemas + (joined.sorted ?= joinedData) && (joinedData + .map(_._2) + .toSet ?= right.toSet) && equalSchemas } check(forAll(prop[Int, Long, String] _)) diff --git a/dataset/src/test/scala/frameless/LitTests.scala b/dataset/src/test/scala/frameless/LitTests.scala index 50df45220..dd1d282d3 100644 --- a/dataset/src/test/scala/frameless/LitTests.scala +++ b/dataset/src/test/scala/frameless/LitTests.scala @@ -9,22 +9,26 @@ import org.scalacheck.Prop, Prop._ import RecordEncoderTests.Name class LitTests extends TypedDatasetSuite with Matchers { - def prop[A: TypedEncoder](value: A)(implicit i0: shapeless.Refute[IsValueClass[A]]): Prop = { + + def prop[A: TypedEncoder]( + value: A + )(implicit + i0: shapeless.Refute[IsValueClass[A]] + ): Prop = { val df: TypedDataset[Int] = TypedDataset.create(1 :: Nil) val l: TypedColumn[Int, A] = lit(value) // filter forces whole codegen - val elems = df.deserialized.filter((_:Int) => true).select(l) + val elems = df.deserialized + .filter((_: Int) => true) + .select(l) .collect() .run() .toVector // otherwise it uses local relation - val localElems = df.select(l) - .collect() - .run() - .toVector + val localElems = df.select(l).collect().run().toVector val expected = Vector(value) @@ -56,23 +60,24 @@ class LitTests extends TypedDatasetSuite with Matchers { } test("support value class") { - val initial = Seq( - Q(name = new Name("Foo"), id = 1), - Q(name = new Name("Bar"), 
id = 2)) + val initial = + Seq(Q(name = new Name("Foo"), id = 1), Q(name = new Name("Bar"), id = 2)) val ds = TypedDataset.create(initial) ds.collect.run() shouldBe initial val lorem = new Name("Lorem") - ds.withColumnReplaced('name, functions.litValue(lorem)). - collect.run() shouldBe initial.map(_.copy(name = lorem)) + ds.withColumnReplaced('name, functions.litValue(lorem)) + .collect + .run() shouldBe initial.map(_.copy(name = lorem)) } test("support optional value class") { val initial = Seq( R(name = "Foo", id = 1, alias = None), - R(name = "Bar", id = 2, alias = Some(new Name("Lorem")))) + R(name = "Bar", id = 2, alias = Some(new Name("Lorem"))) + ) val ds = TypedDataset.create(initial) ds.collect.run() shouldBe initial @@ -82,13 +87,13 @@ class LitTests extends TypedDatasetSuite with Matchers { val lit = functions.litValue(someIpsum) val tds = ds.withColumnReplaced('alias, functions.litValue(someIpsum)) - tds.queryExecution.toString() should include (lit.toString) + tds.queryExecution.toString() should include(lit.toString) - tds. - collect.run() shouldBe initial.map(_.copy(alias = someIpsum)) + tds.collect.run() shouldBe initial.map(_.copy(alias = someIpsum)) - ds.withColumnReplaced('alias, functions.litValue(Option.empty[Name])). - collect.run() shouldBe initial.map(_.copy(alias = None)) + ds.withColumnReplaced('alias, functions.litValue(Option.empty[Name])) + .collect + .run() shouldBe initial.map(_.copy(alias = None)) } test("#205: comparing literals encoded using Injection") { diff --git a/dataset/src/test/scala/frameless/NumericTests.scala b/dataset/src/test/scala/frameless/NumericTests.scala index 0c13ae5a3..7faf8d7b4 100644 --- a/dataset/src/test/scala/frameless/NumericTests.scala +++ b/dataset/src/test/scala/frameless/NumericTests.scala @@ -1,7 +1,7 @@ package frameless import org.apache.spark.sql.Encoder -import org.scalacheck.{Arbitrary, Gen, Prop} +import org.scalacheck.{ Arbitrary, Gen, Prop } import org.scalacheck.Prop._ import org.scalatest.matchers.should.Matchers @@ -43,7 +43,10 @@ class NumericTests extends TypedDatasetSuite with Matchers { } test("multiply") { - def prop[A: TypedEncoder : CatalystNumeric : Numeric : ClassTag](a: A, b: A): Prop = { + def prop[A: TypedEncoder: CatalystNumeric: Numeric: ClassTag]( + a: A, + b: A + ): Prop = { val df = TypedDataset.create(X2(a, b) :: Nil) val result = implicitly[Numeric[A]].times(a, b) val got = df.select(df.col('a) * df.col('b)).collect().run() @@ -59,27 +62,36 @@ class NumericTests extends TypedDatasetSuite with Matchers { } test("divide") { - def prop[A: TypedEncoder: CatalystNumeric: Numeric](a: A, b: A)(implicit cd: CatalystDivisible[A, Double]): Prop = { + def prop[A: TypedEncoder: CatalystNumeric: Numeric]( + a: A, + b: A + )(implicit + cd: CatalystDivisible[A, Double] + ): Prop = { val df = TypedDataset.create(X2(a, b) :: Nil) - if (b == 0) proved else { - val div: Double = implicitly[Numeric[A]].toDouble(a) / implicitly[Numeric[A]].toDouble(b) - val got: Seq[Double] = df.select(df.col('a) / df.col('b)).collect().run() + if (b == 0) proved + else { + val div: Double = implicitly[Numeric[A]] + .toDouble(a) / implicitly[Numeric[A]].toDouble(b) + val got: Seq[Double] = + df.select(df.col('a) / df.col('b)).collect().run() got ?= (div :: Nil) } } - check(prop[Byte ] _) + check(prop[Byte] _) check(prop[Double] _) - check(prop[Int ] _) - check(prop[Long ] _) - check(prop[Short ] _) + check(prop[Int] _) + check(prop[Long] _) + check(prop[Short] _) } test("divide BigDecimals") { def prop(a: BigDecimal, b: BigDecimal): 
Prop = { val df = TypedDataset.create(X2(a, b) :: Nil) - if (b.doubleValue == 0) proved else { + if (b.doubleValue == 0) proved + else { // Spark performs something in between Double division and BigDecimal division, // we approximate it using double vision and `approximatelyEqual`: val div = BigDecimal(a.doubleValue / b.doubleValue) @@ -107,24 +119,31 @@ class NumericTests extends TypedDatasetSuite with Matchers { } object NumericMod { + implicit val byteInstance = new NumericMod[Byte] { def mod(a: Byte, b: Byte) = (a % b).toByte } + implicit val doubleInstance = new NumericMod[Double] { def mod(a: Double, b: Double) = a % b } + implicit val floatInstance = new NumericMod[Float] { def mod(a: Float, b: Float) = a % b } + implicit val intInstance = new NumericMod[Int] { def mod(a: Int, b: Int) = a % b } + implicit val longInstance = new NumericMod[Long] { def mod(a: Long, b: Long) = a % b } + implicit val shortInstance = new NumericMod[Short] { def mod(a: Short, b: Short) = (a % b).toShort } + implicit val bigDecimalInstance = new NumericMod[BigDecimal] { def mod(a: BigDecimal, b: BigDecimal) = a % b } @@ -133,9 +152,10 @@ class NumericTests extends TypedDatasetSuite with Matchers { test("mod") { import NumericMod._ - def prop[A: TypedEncoder : CatalystNumeric : NumericMod](a: A, b: A): Prop = { + def prop[A: TypedEncoder: CatalystNumeric: NumericMod](a: A, b: A): Prop = { val df = TypedDataset.create(X2(a, b) :: Nil) - if (b == 0) proved else { + if (b == 0) proved + else { val mod: A = implicitly[NumericMod[A]].mod(a, b) val got: Seq[A] = df.select(df.col('a) % df.col('b)).collect().run() @@ -145,19 +165,23 @@ class NumericTests extends TypedDatasetSuite with Matchers { check(prop[Byte] _) check(prop[Double] _) - check(prop[Int ] _) - check(prop[Long ] _) - check(prop[Short ] _) + check(prop[Int] _) + check(prop[Long] _) + check(prop[Short] _) check(prop[BigDecimal] _) } - test("a mod lit(b)"){ + test("a mod lit(b)") { import NumericMod._ - def prop[A: TypedEncoder : CatalystNumeric : NumericMod](elem: A, data: X1[A]): Prop = { + def prop[A: TypedEncoder: CatalystNumeric: NumericMod]( + elem: A, + data: X1[A] + ): Prop = { val dataset = TypedDataset.create(Seq(data)) val a = dataset.col('a) - if (elem == 0) proved else { + if (elem == 0) proved + else { val mod: A = implicitly[NumericMod[A]].mod(data.a, elem) val got: Seq[A] = dataset.select(a % elem).collect().run() @@ -167,9 +191,9 @@ class NumericTests extends TypedDatasetSuite with Matchers { check(prop[Byte] _) check(prop[Double] _) - check(prop[Int ] _) - check(prop[Long ] _) - check(prop[Short ] _) + check(prop[Int] _) + check(prop[Long] _) + check(prop[Short] _) check(prop[BigDecimal] _) } @@ -180,12 +204,13 @@ class NumericTests extends TypedDatasetSuite with Matchers { implicit val doubleWithNaN = Arbitrary { implicitly[Arbitrary[Double]].arbitrary.flatMap(Gen.oneOf(_, Double.NaN)) } - implicit val x1 = Arbitrary{ doubleWithNaN.arbitrary.map(X1(_)) } + implicit val x1 = Arbitrary { doubleWithNaN.arbitrary.map(X1(_)) } - def prop[A : TypedEncoder : Encoder : CatalystNaN](data: List[X1[A]]): Prop = { + def prop[A: TypedEncoder: Encoder: CatalystNaN](data: List[X1[A]]): Prop = { val ds = TypedDataset.create(data) - val expected = ds.toDF().filter(!$"a".isNaN).map(_.getAs[A](0)).collect().toSeq + val expected = + ds.toDF().filter(!$"a".isNaN).map(_.getAs[A](0)).collect().toSeq val rs = ds.filter(!ds('a).isNaN).collect().run().map(_.a) rs ?= expected diff --git a/dataset/src/test/scala/frameless/OrderByTests.scala 
b/dataset/src/test/scala/frameless/OrderByTests.scala index 98bd7442d..2b07fa860 100644 --- a/dataset/src/test/scala/frameless/OrderByTests.scala +++ b/dataset/src/test/scala/frameless/OrderByTests.scala @@ -7,19 +7,28 @@ import org.apache.spark.sql.Column import org.scalatest.matchers.should.Matchers class OrderByTests extends TypedDatasetSuite with Matchers { - def sortings[A : CatalystOrdered, T]: Seq[(TypedColumn[T, A] => SortedTypedColumn[T, A], Column => Column)] = Seq( - (_.desc, _.desc), - (_.asc, _.asc), - (t => t, t => t) //default ascending - ) + + def sortings[ + A: CatalystOrdered, + T + ]: Seq[(TypedColumn[T, A] => SortedTypedColumn[T, A], Column => Column)] = + Seq( + (_.desc, _.desc), + (_.asc, _.asc), + (t => t, t => t) // default ascending + ) test("single column non nullable orderBy") { - def prop[A: TypedEncoder : CatalystOrdered](data: Vector[X1[A]]): Prop = { + def prop[A: TypedEncoder: CatalystOrdered](data: Vector[X1[A]]): Prop = { val ds = TypedDataset.create(data) - sortings[A, X1[A]].map { case (typ, untyp) => - ds.dataset.orderBy(untyp(ds.dataset.col("a"))).collect().toVector.?=( - ds.orderBy(typ(ds('a))).collect().run().toVector) + sortings[A, X1[A]].map { + case (typ, untyp) => + ds.dataset + .orderBy(untyp(ds.dataset.col("a"))) + .collect() + .toVector + .?=(ds.orderBy(typ(ds('a))).collect().run().toVector) }.reduce(_ && _) } @@ -36,12 +45,16 @@ class OrderByTests extends TypedDatasetSuite with Matchers { } test("single column non nullable partition sorting") { - def prop[A: TypedEncoder : CatalystOrdered](data: Vector[X1[A]]): Prop = { + def prop[A: TypedEncoder: CatalystOrdered](data: Vector[X1[A]]): Prop = { val ds = TypedDataset.create(data) - sortings[A, X1[A]].map { case (typ, untyp) => - ds.dataset.sortWithinPartitions(untyp(ds.dataset.col("a"))).collect().toVector.?=( - ds.sortWithinPartitions(typ(ds('a))).collect().run().toVector) + sortings[A, X1[A]].map { + case (typ, untyp) => + ds.dataset + .sortWithinPartitions(untyp(ds.dataset.col("a"))) + .collect() + .toVector + .?=(ds.sortWithinPartitions(typ(ds('a))).collect().run().toVector) }.reduce(_ && _) } @@ -58,15 +71,34 @@ class OrderByTests extends TypedDatasetSuite with Matchers { } test("two columns non nullable orderBy") { - def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X2[A,B]]): Prop = { + def prop[ + A: TypedEncoder: CatalystOrdered, + B: TypedEncoder: CatalystOrdered + ](data: Vector[X2[A, B]] + ): Prop = { val ds = TypedDataset.create(data) - sortings[A, X2[A, B]].reverse.zip(sortings[B, X2[A, B]]).map { case ((typA, untypA), (typB, untypB)) => - val vanillaSpark = ds.dataset.orderBy(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b"))).collect().toVector - vanillaSpark.?=(ds.orderBy(typA(ds('a)), typB(ds('b))).collect().run().toVector).&&( - vanillaSpark ?= ds.orderByMany(typA(ds('a)), typB(ds('b))).collect().run().toVector - ) - }.reduce(_ && _) + sortings[A, X2[A, B]].reverse + .zip(sortings[B, X2[A, B]]) + .map { + case ((typA, untypA), (typB, untypB)) => + val vanillaSpark = ds.dataset + .orderBy(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b"))) + .collect() + .toVector + vanillaSpark + .?=( + ds.orderBy(typA(ds('a)), typB(ds('b))).collect().run().toVector + ) + .&&( + vanillaSpark ?= ds + .orderByMany(typA(ds('a)), typB(ds('b))) + .collect() + .run() + .toVector + ) + } + .reduce(_ && _) } check(forAll(prop[SQLDate, Long] _)) @@ -75,15 +107,40 @@ class OrderByTests extends TypedDatasetSuite with Matchers { } test("two columns 
non nullable partition sorting") { - def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X2[A,B]]): Prop = { + def prop[ + A: TypedEncoder: CatalystOrdered, + B: TypedEncoder: CatalystOrdered + ](data: Vector[X2[A, B]] + ): Prop = { val ds = TypedDataset.create(data) - sortings[A, X2[A, B]].reverse.zip(sortings[B, X2[A, B]]).map { case ((typA, untypA), (typB, untypB)) => - val vanillaSpark = ds.dataset.sortWithinPartitions(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b"))).collect().toVector - vanillaSpark.?=(ds.sortWithinPartitions(typA(ds('a)), typB(ds('b))).collect().run().toVector).&&( - vanillaSpark ?= ds.sortWithinPartitionsMany(typA(ds('a)), typB(ds('b))).collect().run().toVector - ) - }.reduce(_ && _) + sortings[A, X2[A, B]].reverse + .zip(sortings[B, X2[A, B]]) + .map { + case ((typA, untypA), (typB, untypB)) => + val vanillaSpark = ds.dataset + .sortWithinPartitions( + untypA(ds.dataset.col("a")), + untypB(ds.dataset.col("b")) + ) + .collect() + .toVector + vanillaSpark + .?=( + ds.sortWithinPartitions(typA(ds('a)), typB(ds('b))) + .collect() + .run() + .toVector + ) + .&&( + vanillaSpark ?= ds + .sortWithinPartitionsMany(typA(ds('a)), typB(ds('b))) + .collect() + .run() + .toVector + ) + } + .reduce(_ && _) } check(forAll(prop[SQLDate, Long] _)) @@ -92,21 +149,43 @@ class OrderByTests extends TypedDatasetSuite with Matchers { } test("three columns non nullable orderBy") { - def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X3[A,B,A]]): Prop = { + def prop[ + A: TypedEncoder: CatalystOrdered, + B: TypedEncoder: CatalystOrdered + ](data: Vector[X3[A, B, A]] + ): Prop = { val ds = TypedDataset.create(data) sortings[A, X3[A, B, A]].reverse .zip(sortings[B, X3[A, B, A]]) .zip(sortings[A, X3[A, B, A]]) - .map { case (((typA, untypA), (typB, untypB)), (typA2, untypA2)) => - val vanillaSpark = ds.dataset - .orderBy(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b")), untypA2(ds.dataset.col("c"))) - .collect().toVector - - vanillaSpark.?=(ds.orderBy(typA(ds('a)), typB(ds('b)), typA2(ds('c))).collect().run().toVector).&&( - vanillaSpark ?= ds.orderByMany(typA(ds('a)), typB(ds('b)), typA2(ds('c))).collect().run().toVector - ) - }.reduce(_ && _) + .map { + case (((typA, untypA), (typB, untypB)), (typA2, untypA2)) => + val vanillaSpark = ds.dataset + .orderBy( + untypA(ds.dataset.col("a")), + untypB(ds.dataset.col("b")), + untypA2(ds.dataset.col("c")) + ) + .collect() + .toVector + + vanillaSpark + .?=( + ds.orderBy(typA(ds('a)), typB(ds('b)), typA2(ds('c))) + .collect() + .run() + .toVector + ) + .&&( + vanillaSpark ?= ds + .orderByMany(typA(ds('a)), typB(ds('b)), typA2(ds('c))) + .collect() + .run() + .toVector + ) + } + .reduce(_ && _) } check(forAll(prop[SQLDate, Long] _)) @@ -115,21 +194,50 @@ class OrderByTests extends TypedDatasetSuite with Matchers { } test("three columns non nullable partition sorting") { - def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X3[A,B,A]]): Prop = { + def prop[ + A: TypedEncoder: CatalystOrdered, + B: TypedEncoder: CatalystOrdered + ](data: Vector[X3[A, B, A]] + ): Prop = { val ds = TypedDataset.create(data) sortings[A, X3[A, B, A]].reverse .zip(sortings[B, X3[A, B, A]]) .zip(sortings[A, X3[A, B, A]]) - .map { case (((typA, untypA), (typB, untypB)), (typA2, untypA2)) => - val vanillaSpark = ds.dataset - .sortWithinPartitions(untypA(ds.dataset.col("a")), untypB(ds.dataset.col("b")), untypA2(ds.dataset.col("c"))) - 
.collect().toVector - - vanillaSpark.?=(ds.sortWithinPartitions(typA(ds('a)), typB(ds('b)), typA2(ds('c))).collect().run().toVector).&&( - vanillaSpark ?= ds.sortWithinPartitionsMany(typA(ds('a)), typB(ds('b)), typA2(ds('c))).collect().run().toVector - ) - }.reduce(_ && _) + .map { + case (((typA, untypA), (typB, untypB)), (typA2, untypA2)) => + val vanillaSpark = ds.dataset + .sortWithinPartitions( + untypA(ds.dataset.col("a")), + untypB(ds.dataset.col("b")), + untypA2(ds.dataset.col("c")) + ) + .collect() + .toVector + + vanillaSpark + .?=( + ds.sortWithinPartitions( + typA(ds('a)), + typB(ds('b)), + typA2(ds('c)) + ).collect() + .run() + .toVector + ) + .&&( + vanillaSpark ?= ds + .sortWithinPartitionsMany( + typA(ds('a)), + typB(ds('b)), + typA2(ds('c)) + ) + .collect() + .run() + .toVector + ) + } + .reduce(_ && _) } check(forAll(prop[SQLDate, Long] _)) @@ -138,13 +246,28 @@ class OrderByTests extends TypedDatasetSuite with Matchers { } test("sort support for mixed default and explicit ordering") { - def prop[A: TypedEncoder : CatalystOrdered, B: TypedEncoder : CatalystOrdered](data: Vector[X2[A, B]]): Prop = { + def prop[ + A: TypedEncoder: CatalystOrdered, + B: TypedEncoder: CatalystOrdered + ](data: Vector[X2[A, B]] + ): Prop = { val ds = TypedDataset.create(data) - ds.dataset.orderBy(ds.dataset.col("a"), ds.dataset.col("b").desc).collect().toVector.?=( - ds.orderByMany(ds('a), ds('b).desc).collect().run().toVector) && - ds.dataset.sortWithinPartitions(ds.dataset.col("a"), ds.dataset.col("b").desc).collect().toVector.?=( - ds.sortWithinPartitionsMany(ds('a), ds('b).desc).collect().run().toVector) + ds.dataset + .orderBy(ds.dataset.col("a"), ds.dataset.col("b").desc) + .collect() + .toVector + .?=(ds.orderByMany(ds('a), ds('b).desc).collect().run().toVector) && + ds.dataset + .sortWithinPartitions(ds.dataset.col("a"), ds.dataset.col("b").desc) + .collect() + .toVector + .?=( + ds.sortWithinPartitionsMany(ds('a), ds('b).desc) + .collect() + .run() + .toVector + ) } check(forAll(prop[SQLDate, Long] _)) @@ -159,50 +282,67 @@ class OrderByTests extends TypedDatasetSuite with Matchers { illTyped("""d.sortWithinPartitions(d('b).desc)""") } - test("derives a CatalystOrdered for case classes when all fields are comparable") { + test( + "derives a CatalystOrdered for case classes when all fields are comparable" + ) { type T[A, B] = X3[Int, Boolean, X2[A, B]] def prop[ - A: TypedEncoder : CatalystOrdered, - B: TypedEncoder : CatalystOrdered - ](data: Vector[T[A, B]]): Prop = { + A: TypedEncoder: CatalystOrdered, + B: TypedEncoder: CatalystOrdered + ](data: Vector[T[A, B]] + ): Prop = { val ds = TypedDataset.create(data) - sortings[X2[A, B], T[A, B]].map { case (typX2, untypX2) => - val vanilla = ds.dataset.orderBy(untypX2(ds.dataset.col("c"))).collect().toVector - val frameless = ds.orderBy(typX2(ds('c))).collect().run.toVector - vanilla ?= frameless + sortings[X2[A, B], T[A, B]].map { + case (typX2, untypX2) => + val vanilla = + ds.dataset.orderBy(untypX2(ds.dataset.col("c"))).collect().toVector + val frameless = ds.orderBy(typX2(ds('c))).collect().run.toVector + vanilla ?= frameless }.reduce(_ && _) } check(forAll(prop[Int, Long] _)) check(forAll(prop[(String, SQLDate), Float] _)) // Check that nested case classes are properly derived too - check(forAll(prop[X2[Boolean, Float], X4[SQLTimestamp, Double, Short, Byte]] _)) + check( + forAll(prop[X2[Boolean, Float], X4[SQLTimestamp, Double, Short, Byte]] _) + ) } test("derives a CatalystOrdered for tuples when all fields are comparable") { type 
T[A, B] = X2[Int, (A, B)] def prop[ - A: TypedEncoder : CatalystOrdered, - B: TypedEncoder : CatalystOrdered - ](data: Vector[T[A, B]]): Prop = { + A: TypedEncoder: CatalystOrdered, + B: TypedEncoder: CatalystOrdered + ](data: Vector[T[A, B]] + ): Prop = { val ds = TypedDataset.create(data) - sortings[(A, B), T[A, B]].map { case (typX2, untypX2) => - val vanilla = ds.dataset.orderBy(untypX2(ds.dataset.col("b"))).collect().toVector - val frameless = ds.orderBy(typX2(ds('b))).collect().run.toVector - vanilla ?= frameless + sortings[(A, B), T[A, B]].map { + case (typX2, untypX2) => + val vanilla = + ds.dataset.orderBy(untypX2(ds.dataset.col("b"))).collect().toVector + val frameless = ds.orderBy(typX2(ds('b))).collect().run.toVector + vanilla ?= frameless }.reduce(_ && _) } check(forAll(prop[Int, Long] _)) check(forAll(prop[(String, SQLDate), Float] _)) - check(forAll(prop[X2[Boolean, Float], X1[(SQLTimestamp, Double, Short, Byte)]] _)) + check( + forAll( + prop[X2[Boolean, Float], X1[(SQLTimestamp, Double, Short, Byte)]] _ + ) + ) } test("fails to compile when one of the field isn't comparable") { type T = X2[Int, X2[Int, Map[String, String]]] val d = TypedDataset.create(X2(1, X2(2, Map("not" -> "comparable"))) :: Nil) - illTyped("d.orderBy(d('b).desc)", """Cannot compare columns of type frameless.X2\[Int,scala.collection.immutable.Map\[String,String]].""") + illTyped( + "d.orderBy(d('b).desc)", + """Cannot compare columns of type frameless.X2\[Int,scala.collection.immutable.Map\[String,String]].""" + ) } } diff --git a/dataset/src/test/scala/frameless/RecordEncoderTests.scala b/dataset/src/test/scala/frameless/RecordEncoderTests.scala index 98274cf01..121785371 100644 --- a/dataset/src/test/scala/frameless/RecordEncoderTests.scala +++ b/dataset/src/test/scala/frameless/RecordEncoderTests.scala @@ -1,6 +1,6 @@ package frameless -import org.apache.spark.sql.{Row, functions => F} +import org.apache.spark.sql.{ Row, functions => F } import org.apache.spark.sql.types.{ ArrayType, BinaryType, @@ -14,7 +14,7 @@ import org.apache.spark.sql.types.{ StructType } -import shapeless.{HList, LabelledGeneric} +import shapeless.{ HList, LabelledGeneric } import shapeless.test.illTyped import org.scalatest.matchers.should.Matchers @@ -25,19 +25,31 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { } test("Dropping fields") { - def dropUnitValues[L <: HList](l: L)(implicit d: DropUnitValues[L]): d.Out = d(l) - val fields = LabelledGeneric[TupleWithUnits].to(TupleWithUnits(42, "something")) - dropUnitValues(fields) shouldEqual LabelledGeneric[(Int, String)].to((42, "something")) + def dropUnitValues[L <: HList]( + l: L + )(implicit + d: DropUnitValues[L] + ): d.Out = d(l) + val fields = + LabelledGeneric[TupleWithUnits].to(TupleWithUnits(42, "something")) + dropUnitValues(fields) shouldEqual LabelledGeneric[(Int, String)] + .to((42, "something")) } test("Representation skips units") { - assert(TypedEncoder[(Int, String)].catalystRepr == TypedEncoder[TupleWithUnits].catalystRepr) + assert( + TypedEncoder[(Int, String)].catalystRepr == TypedEncoder[ + TupleWithUnits + ].catalystRepr + ) } test("Serialization skips units") { val df = session.createDataFrame(Seq((1, "one"), (2, "two"))) val ds = df.as[TupleWithUnits](TypedExpressionEncoder[TupleWithUnits]) - val tds = TypedDataset.create(Seq(TupleWithUnits(1, "one"), TupleWithUnits(2, "two"))) + val tds = TypedDataset.create( + Seq(TupleWithUnits(1, "one"), TupleWithUnits(2, "two")) + ) df.collect shouldEqual tds.toDF.collect 
ds.collect.toSeq shouldEqual tds.collect.run @@ -51,7 +63,8 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { test("Empty nested record value becomes none on deserialization") { val rdd = sc.parallelize(Seq(Row(null))) - val schema = TypedEncoder[OptionalNesting].catalystRepr.asInstanceOf[StructType] + val schema = + TypedEncoder[OptionalNesting].catalystRepr.asInstanceOf[StructType] val df = session.createDataFrame(rdd, schema) val ds = TypedDataset.createUnsafe(df)(TypedEncoder[OptionalNesting]) @@ -61,7 +74,8 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { test("Deeply nested optional values have correct deserialization") { val rdd = sc.parallelize(Seq(Row(true, Row(null, null)))) type NestedOptionPair = X2[Boolean, Option[X2[Option[Int], Option[String]]]] - val schema = TypedEncoder[NestedOptionPair].catalystRepr.asInstanceOf[StructType] + val schema = + TypedEncoder[NestedOptionPair].catalystRepr.asInstanceOf[StructType] val df = session.createDataFrame(rdd, schema) val ds = TypedDataset.createUnsafe(df)(TypedEncoder[NestedOptionPair]) ds.firstOption.run.get shouldBe X2(true, Some(X2(None, None))) @@ -95,14 +109,16 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { encoder.jvmRepr shouldBe ObjectType(classOf[Name]) encoder.catalystRepr shouldBe StructType( - Seq(StructField("value", StringType, false))) + Seq(StructField("value", StringType, false)) + ) val sqlContext = session.sqlContext import sqlContext.implicits._ TypedDataset .createUnsafe[Name](Seq("Foo", "Bar").toDF)(encoder) - .collect().run() shouldBe Seq(new Name("Foo"), new Name("Bar")) + .collect() + .run() shouldBe Seq(new Name("Foo"), new Name("Bar")) } @@ -111,7 +127,8 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { illTyped( // As `Person` is not a Value class - "val _: RecordFieldEncoder[Person] = RecordFieldEncoder.valueClass") + "val _: RecordFieldEncoder[Person] = RecordFieldEncoder.valueClass" + ) val fieldEncoder: RecordFieldEncoder[Name] = RecordFieldEncoder.valueClass @@ -123,24 +140,28 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { encoder.jvmRepr shouldBe ObjectType(classOf[Person]) - val expectedPersonStructType = StructType(Seq( - StructField("name", StringType, false), - StructField("age", IntegerType, false))) + val expectedPersonStructType = StructType( + Seq( + StructField("name", StringType, false), + StructField("age", IntegerType, false) + ) + ) encoder.catalystRepr shouldBe expectedPersonStructType val unsafeDs: TypedDataset[Person] = { - val rdd = sc.parallelize(Seq( - Row.fromTuple("Foo" -> 2), - Row.fromTuple("Bar" -> 3) - )) + val rdd = sc.parallelize( + Seq( + Row.fromTuple("Foo" -> 2), + Row.fromTuple("Bar" -> 3) + ) + ) val df = session.createDataFrame(rdd, expectedPersonStructType) TypedDataset.createUnsafe(df)(encoder) } - val expected = Seq( - Person(new Name("Foo"), 2), Person(new Name("Bar"), 3)) + val expected = Seq(Person(new Name("Foo"), 2), Person(new Name("Bar"), 3)) unsafeDs.collect.run() shouldBe expected @@ -151,8 +172,10 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { val lorem = new Name("Lorem") - safeDs.withColumnReplaced('name, functions.litValue(lorem)). 
- collect.run() shouldBe expected.map(_.copy(name = lorem)) + safeDs + .withColumnReplaced('name, functions.litValue(lorem)) + .collect + .run() shouldBe expected.map(_.copy(name = lorem)) } test("Case class with value class as optional field") { @@ -160,7 +183,8 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { illTyped( // As `Person` is not a Value class """val _: RecordFieldEncoder[Option[Person]] = - RecordFieldEncoder.optionValueClass""") + RecordFieldEncoder.optionValueClass""" + ) val fieldEncoder: RecordFieldEncoder[Option[Name]] = RecordFieldEncoder.optionValueClass @@ -168,33 +192,37 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { fieldEncoder.encoder.catalystRepr shouldBe StringType fieldEncoder.encoder. // !StringType - jvmRepr shouldBe ObjectType(classOf[Option[_]]) + jvmRepr shouldBe ObjectType(classOf[Option[_]]) // Encode as a Person field val encoder = TypedEncoder[User] encoder.jvmRepr shouldBe ObjectType(classOf[User]) - val expectedPersonStructType = StructType(Seq( - StructField("id", LongType, false), - StructField("name", StringType, true))) + val expectedPersonStructType = StructType( + Seq( + StructField("id", LongType, false), + StructField("name", StringType, true) + ) + ) encoder.catalystRepr shouldBe expectedPersonStructType val ds1: TypedDataset[User] = { - val rdd = sc.parallelize(Seq( - Row(1L, null), - Row(2L, "Foo") - )) + val rdd = sc.parallelize( + Seq( + Row(1L, null), + Row(2L, "Foo") + ) + ) val df = session.createDataFrame(rdd, expectedPersonStructType) TypedDataset.createUnsafe(df)(encoder) } - ds1.collect.run() shouldBe Seq( - User(1L, None), - User(2L, Some(new Name("Foo")))) + ds1.collect + .run() shouldBe Seq(User(1L, None), User(2L, Some(new Name("Foo")))) val ds2: TypedDataset[User] = { val sqlContext = session.sqlContext @@ -206,18 +234,18 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { """{"id":5,"name":null}""" ).toDF - val df2 = df1.withColumn( - "jsonValue", - F.from_json(df1.col("value"), expectedPersonStructType)). - select("jsonValue.id", "jsonValue.name") + val df2 = df1 + .withColumn( + "jsonValue", + F.from_json(df1.col("value"), expectedPersonStructType) + ) + .select("jsonValue.id", "jsonValue.name") TypedDataset.createUnsafe[User](df2) } - val expected = Seq( - User(3L, None), - User(4L, Some(new Name("Lorem"))), - User(5L, None)) + val expected = + Seq(User(3L, None), User(4L, Some(new Name("Lorem"))), User(5L, None)) ds2.collect.run() shouldBe expected @@ -232,11 +260,19 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { encoder.jvmRepr shouldBe ObjectType(classOf[D]) - val expectedStructType = StructType(Seq( - StructField("m", MapType( - keyType = StringType, - valueType = IntegerType, - valueContainsNull = false), false))) + val expectedStructType = StructType( + Seq( + StructField( + "m", + MapType( + keyType = StringType, + valueType = IntegerType, + valueContainsNull = false + ), + false + ) + ) + ) encoder.catalystRepr shouldBe expectedStructType @@ -246,18 +282,19 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { val ds1 = TypedDataset.createUnsafe[D] { val df = Seq( """{"m":{"pizza":1,"sushi":2}}""", - """{"m":{"red":3,"blue":4}}""", + """{"m":{"red":3,"blue":4}}""" ).toDF df.withColumn( "jsonValue", - F.from_json(df.col("value"), expectedStructType)). 
- select("jsonValue.*") + F.from_json(df.col("value"), expectedStructType) + ).select("jsonValue.*") } val expected = Seq( D(m = Map("pizza" -> 1, "sushi" -> 2)), - D(m = Map("red" -> 3, "blue" -> 4))) + D(m = Map("red" -> 3, "blue" -> 4)) + ) ds1.collect.run() shouldBe expected @@ -275,12 +312,20 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { encoder.jvmRepr shouldBe ObjectType(classOf[Student]) - val expectedStudentStructType = StructType(Seq( - StructField("name", StringType, false), - StructField("grades", MapType( - keyType = StringType, - valueType = DecimalType.SYSTEM_DEFAULT, - valueContainsNull = false), false))) + val expectedStudentStructType = StructType( + Seq( + StructField("name", StringType, false), + StructField( + "grades", + MapType( + keyType = StringType, + valueType = DecimalType.SYSTEM_DEFAULT, + valueContainsNull = false + ), + false + ) + ) + ) encoder.catalystRepr shouldBe expectedStudentStructType @@ -290,51 +335,65 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { val ds1 = TypedDataset.createUnsafe[Student] { val df = Seq( """{"name":"Foo","grades":{"math":1,"physics":"23.4"}}""", - """{"name":"Bar","grades":{"biology":18.5,"geography":4}}""", + """{"name":"Bar","grades":{"biology":18.5,"geography":4}}""" ).toDF df.withColumn( "jsonValue", - F.from_json(df.col("value"), expectedStudentStructType)). - select("jsonValue.*") + F.from_json(df.col("value"), expectedStudentStructType) + ).select("jsonValue.*") } val expected = Seq( - Student(name = "Foo", grades = Map( - new Subject("math") -> new Grade(BigDecimal(1)), - new Subject("physics") -> new Grade(BigDecimal(23.4D)))), - Student(name = "Bar", grades = Map( - new Subject("biology") -> new Grade(BigDecimal(18.5)), - new Subject("geography") -> new Grade(BigDecimal(4L))))) + Student( + name = "Foo", + grades = Map( + new Subject("math") -> new Grade(BigDecimal(1)), + new Subject("physics") -> new Grade(BigDecimal(23.4D)) + ) + ), + Student( + name = "Bar", + grades = Map( + new Subject("biology") -> new Grade(BigDecimal(18.5)), + new Subject("geography") -> new Grade(BigDecimal(4L)) + ) + ) + ) ds1.collect.run() shouldBe expected val grades = Map[Subject, Grade]( - new Subject("any") -> new Grade(BigDecimal(Long.MaxValue) + 1L)) + new Subject("any") -> new Grade(BigDecimal(Long.MaxValue) + 1L) + ) val ds2 = ds1.withColumnReplaced('grades, functions.lit(grades)) - ds2.collect.run() shouldBe Seq( - Student("Foo", grades), Student("Bar", grades)) + ds2.collect + .run() shouldBe Seq(Student("Foo", grades), Student("Bar", grades)) } test("Encode binary array") { val encoder = TypedEncoder[Tuple2[String, Array[Byte]]] - encoder.jvmRepr shouldBe ObjectType( - classOf[Tuple2[String, Array[Byte]]]) + encoder.jvmRepr shouldBe ObjectType(classOf[Tuple2[String, Array[Byte]]]) - val expectedStructType = StructType(Seq( - StructField("_1", StringType, false), - StructField("_2", BinaryType, false))) + val expectedStructType = StructType( + Seq( + StructField("_1", StringType, false), + StructField("_2", BinaryType, false) + ) + ) encoder.catalystRepr shouldBe expectedStructType val ds1: TypedDataset[(String, Array[Byte])] = { - val rdd = sc.parallelize(Seq( - Row.fromTuple("Foo" -> Array[Byte](3, 4)), - Row.fromTuple("Bar" -> Array[Byte](5)) - )) + val rdd = sc.parallelize( + Seq( + Row.fromTuple("Foo" -> Array[Byte](3, 4)), + Row.fromTuple("Bar" -> Array[Byte](5)) + ) + ) val df = session.createDataFrame(rdd, expectedStructType) TypedDataset.createUnsafe(df)(encoder) @@ 
-342,28 +401,27 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { val expected = Seq("Foo" -> Seq[Byte](3, 4), "Bar" -> Seq[Byte](5)) - ds1.collect.run().map { - case (_1, _2) => _1 -> _2.toSeq - } shouldBe expected + ds1.collect.run().map { case (_1, _2) => _1 -> _2.toSeq } shouldBe expected val subjects = "lorem".getBytes("UTF-8").toSeq val ds2 = ds1.withColumnReplaced('_2, functions.lit(subjects.toArray)) - ds2.collect.run().map { - case (_1, _2) => _1 -> _2.toSeq - } shouldBe expected.map(_.copy(_2 = subjects)) + ds2.collect.run().map { case (_1, _2) => _1 -> _2.toSeq } shouldBe expected + .map(_.copy(_2 = subjects)) } test("Encode simple array") { val encoder = TypedEncoder[Tuple2[String, Array[Int]]] - encoder.jvmRepr shouldBe ObjectType( - classOf[Tuple2[String, Array[Int]]]) + encoder.jvmRepr shouldBe ObjectType(classOf[Tuple2[String, Array[Int]]]) - val expectedStructType = StructType(Seq( - StructField("_1", StringType, false), - StructField("_2", ArrayType(IntegerType, false), false))) + val expectedStructType = StructType( + Seq( + StructField("_1", StringType, false), + StructField("_2", ArrayType(IntegerType, false), false) + ) + ) encoder.catalystRepr shouldBe expectedStructType @@ -373,28 +431,25 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { val ds1 = TypedDataset.createUnsafe[(String, Array[Int])] { val df = Seq( """{"_1":"Foo", "_2":[3, 4]}""", - """{"_1":"Bar", "_2":[5]}""", + """{"_1":"Bar", "_2":[5]}""" ).toDF df.withColumn( "jsonValue", - F.from_json(df.col("value"), expectedStructType)). - select("jsonValue.*") + F.from_json(df.col("value"), expectedStructType) + ).select("jsonValue.*") } val expected = Seq("Foo" -> Seq(3, 4), "Bar" -> Seq(5)) - ds1.collect.run().map { - case (_1, _2) => _1 -> _2.toSeq - } shouldBe expected + ds1.collect.run().map { case (_1, _2) => _1 -> _2.toSeq } shouldBe expected val subjects = Seq(6, 6, 7) val ds2 = ds1.withColumnReplaced('_2, functions.lit(subjects.toArray)) - ds2.collect.run().map { - case (_1, _2) => _1 -> _2.toSeq - } shouldBe expected.map(_.copy(_2 = subjects)) + ds2.collect.run().map { case (_1, _2) => _1 -> _2.toSeq } shouldBe expected + .map(_.copy(_2 = subjects)) } test("Encode array of Value class") { @@ -402,12 +457,14 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { val encoder = TypedEncoder[Tuple2[String, Array[Subject]]] - encoder.jvmRepr shouldBe ObjectType( - classOf[Tuple2[String, Array[Subject]]]) + encoder.jvmRepr shouldBe ObjectType(classOf[Tuple2[String, Array[Subject]]]) - val expectedStructType = StructType(Seq( - StructField("_1", StringType, false), - StructField("_2", ArrayType(StringType, false), false))) + val expectedStructType = StructType( + Seq( + StructField("_1", StringType, false), + StructField("_2", ArrayType(StringType, false), false) + ) + ) encoder.catalystRepr shouldBe expectedStructType @@ -417,30 +474,28 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { val ds1 = TypedDataset.createUnsafe[(String, Array[Subject])] { val df = Seq( """{"_1":"Foo", "_2":["math","physics"]}""", - """{"_1":"Bar", "_2":["biology","geography"]}""", + """{"_1":"Bar", "_2":["biology","geography"]}""" ).toDF df.withColumn( "jsonValue", - F.from_json(df.col("value"), expectedStructType)). 
- select("jsonValue.*") + F.from_json(df.col("value"), expectedStructType) + ).select("jsonValue.*") } val expected = Seq( "Foo" -> Seq(new Subject("math"), new Subject("physics")), - "Bar" -> Seq(new Subject("biology"), new Subject("geography"))) + "Bar" -> Seq(new Subject("biology"), new Subject("geography")) + ) - ds1.collect.run().map { - case (_1, _2) => _1 -> _2.toSeq - } shouldBe expected + ds1.collect.run().map { case (_1, _2) => _1 -> _2.toSeq } shouldBe expected val subjects = Seq(new Subject("lorem"), new Subject("ipsum")) val ds2 = ds1.withColumnReplaced('_2, functions.lit(subjects.toArray)) - ds2.collect.run().map { - case (_1, _2) => _1 -> _2.toSeq - } shouldBe expected.map(_.copy(_2 = subjects)) + ds2.collect.run().map { case (_1, _2) => _1 -> _2.toSeq } shouldBe expected + .map(_.copy(_2 = subjects)) } test("Encode case class with simple Seq") { @@ -450,22 +505,41 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { encoder.jvmRepr shouldBe ObjectType(classOf[B]) - val expectedStructType = StructType(Seq( - StructField("a", ArrayType(StructType(Seq( - StructField("x", IntegerType, false))), false), false))) + val expectedStructType = StructType( + Seq( + StructField( + "a", + ArrayType( + StructType(Seq(StructField("x", IntegerType, false))), + false + ), + false + ) + ) + ) encoder.catalystRepr shouldBe expectedStructType val ds1: TypedDataset[B] = { - val rdd = sc.parallelize(Seq( - Row.fromTuple(Tuple1(Seq( - Row.fromTuple(Tuple1[Int](1)), - Row.fromTuple(Tuple1[Int](3)) - ))), - Row.fromTuple(Tuple1(Seq( - Row.fromTuple(Tuple1[Int](2)) - ))) - )) + val rdd = sc.parallelize( + Seq( + Row.fromTuple( + Tuple1( + Seq( + Row.fromTuple(Tuple1[Int](1)), + Row.fromTuple(Tuple1[Int](3)) + ) + ) + ), + Row.fromTuple( + Tuple1( + Seq( + Row.fromTuple(Tuple1[Int](2)) + ) + ) + ) + ) + ) val df = session.createDataFrame(rdd, expectedStructType) TypedDataset.createUnsafe(df)(encoder) @@ -489,9 +563,12 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { encoder.jvmRepr shouldBe ObjectType(classOf[Tuple2[Int, Seq[Name]]]) - val expectedStructType = StructType(Seq( - StructField("_1", IntegerType, false), - StructField("_2", ArrayType(StringType, false), false))) + val expectedStructType = StructType( + Seq( + StructField("_1", IntegerType, false), + StructField("_2", ArrayType(StringType, false), false) + ) + ) encoder.catalystRepr shouldBe expectedStructType @@ -501,18 +578,19 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { val df = Seq( """{"_1":1, "_2":["foo", "bar"]}""", - """{"_1":2, "_2":["lorem"]}""", + """{"_1":2, "_2":["lorem"]}""" ).toDF df.withColumn( "jsonValue", - F.from_json(df.col("value"), expectedStructType)). 
- select("jsonValue.*") + F.from_json(df.col("value"), expectedStructType) + ).select("jsonValue.*") } val expected = Seq( 1 -> Seq(new Name("foo"), new Name("bar")), - 2 -> Seq(new Name("lorem"))) + 2 -> Seq(new Name("lorem")) + ) ds1.collect.run() shouldBe expected } @@ -523,9 +601,15 @@ final class RecordEncoderTests extends TypedDatasetSuite with Matchers { case class UnitsOnly(a: Unit, b: Unit) case class TupleWithUnits( - u0: Unit, _1: Int, u1: Unit, u2: Unit, _2: String, u3: Unit) + u0: Unit, + _1: Int, + u1: Unit, + u2: Unit, + _2: String, + u3: Unit) object TupleWithUnits { + def apply(_1: Int, _2: String): TupleWithUnits = TupleWithUnits((), _1, (), (), _2, ()) } diff --git a/dataset/src/test/scala/frameless/SchemaTests.scala b/dataset/src/test/scala/frameless/SchemaTests.scala index 92fd33057..c93c92827 100644 --- a/dataset/src/test/scala/frameless/SchemaTests.scala +++ b/dataset/src/test/scala/frameless/SchemaTests.scala @@ -10,10 +10,13 @@ import org.scalatest.matchers.should.Matchers class SchemaTests extends TypedDatasetSuite with Matchers { def structToNonNullable(struct: StructType): StructType = { - StructType(struct.fields.map( f => f.copy(nullable = false))) + StructType(struct.fields.map(f => f.copy(nullable = false))) } - def prop[A](dataset: TypedDataset[A], ignoreNullable: Boolean = false): Prop = { + def prop[A]( + dataset: TypedDataset[A], + ignoreNullable: Boolean = false + ): Prop = { val schema = dataset.dataset.schema Prop.all( @@ -24,7 +27,9 @@ class SchemaTests extends TypedDatasetSuite with Matchers { if (!ignoreNullable) TypedExpressionEncoder.targetStructType(dataset.encoder) ?= schema else - structToNonNullable(TypedExpressionEncoder.targetStructType(dataset.encoder)) ?= structToNonNullable(schema) + structToNonNullable( + TypedExpressionEncoder.targetStructType(dataset.encoder) + ) ?= structToNonNullable(schema) ) } diff --git a/dataset/src/test/scala/frameless/SelectTests.scala b/dataset/src/test/scala/frameless/SelectTests.scala index 8043fc941..ced762726 100644 --- a/dataset/src/test/scala/frameless/SelectTests.scala +++ b/dataset/src/test/scala/frameless/SelectTests.scala @@ -7,12 +7,13 @@ import scala.reflect.ClassTag class SelectTests extends TypedDatasetSuite { test("select('a) FROM abcd") { - def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])( - implicit - ea: TypedEncoder[A], - ex4: TypedEncoder[X4[A, B, C, D]], - ca: ClassTag[A] - ): Prop = { + def prop[A, B, C, D]( + data: Vector[X4[A, B, C, D]] + )(implicit + ea: TypedEncoder[A], + ex4: TypedEncoder[X4[A, B, C, D]], + ca: ClassTag[A] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) @@ -29,14 +30,15 @@ class SelectTests extends TypedDatasetSuite { } test("select('a, 'b) FROM abcd") { - def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])( - implicit - ea: TypedEncoder[A], - eb: TypedEncoder[B], - eab: TypedEncoder[(A, B)], - ex4: TypedEncoder[X4[A, B, C, D]], - ca: ClassTag[A] - ): Prop = { + def prop[A, B, C, D]( + data: Vector[X4[A, B, C, D]] + )(implicit + ea: TypedEncoder[A], + eb: TypedEncoder[B], + eab: TypedEncoder[(A, B)], + ex4: TypedEncoder[X4[A, B, C, D]], + ca: ClassTag[A] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) @@ -53,15 +55,16 @@ class SelectTests extends TypedDatasetSuite { } test("select('a, 'b, 'c) FROM abcd") { - def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])( - implicit - ea: TypedEncoder[A], - eb: TypedEncoder[B], - ec: TypedEncoder[C], - eab: TypedEncoder[(A, B, C)], - 
ex4: TypedEncoder[X4[A, B, C, D]], - ca: ClassTag[A] - ): Prop = { + def prop[A, B, C, D]( + data: Vector[X4[A, B, C, D]] + )(implicit + ea: TypedEncoder[A], + eb: TypedEncoder[B], + ec: TypedEncoder[C], + eab: TypedEncoder[(A, B, C)], + ex4: TypedEncoder[X4[A, B, C, D]], + ca: ClassTag[A] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) @@ -79,15 +82,16 @@ class SelectTests extends TypedDatasetSuite { } test("select('a,'b,'c,'d) FROM abcd") { - def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])( - implicit - ea: TypedEncoder[A], - eb: TypedEncoder[B], - ec: TypedEncoder[C], - ed: TypedEncoder[D], - ex4: TypedEncoder[X4[A, B, C, D]], - ca: ClassTag[A] - ): Prop = { + def prop[A, B, C, D]( + data: Vector[X4[A, B, C, D]] + )(implicit + ea: TypedEncoder[A], + eb: TypedEncoder[B], + ec: TypedEncoder[C], + ed: TypedEncoder[D], + ex4: TypedEncoder[X4[A, B, C, D]], + ca: ClassTag[A] + ): Prop = { val dataset = TypedDataset.create(data) val a1 = dataset.col[A]('a) val a2 = dataset.col[B]('b) @@ -106,15 +110,16 @@ class SelectTests extends TypedDatasetSuite { } test("select('a,'b,'c,'d,'a) FROM abcd") { - def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])( - implicit - ea: TypedEncoder[A], - eb: TypedEncoder[B], - ec: TypedEncoder[C], - ed: TypedEncoder[D], - ex4: TypedEncoder[X4[A, B, C, D]], - ca: ClassTag[A] - ): Prop = { + def prop[A, B, C, D]( + data: Vector[X4[A, B, C, D]] + )(implicit + ea: TypedEncoder[A], + eb: TypedEncoder[B], + ec: TypedEncoder[C], + ed: TypedEncoder[D], + ex4: TypedEncoder[X4[A, B, C, D]], + ca: ClassTag[A] + ): Prop = { val dataset = TypedDataset.create(data) val a1 = dataset.col[A]('a) val a2 = dataset.col[B]('b) @@ -133,22 +138,24 @@ class SelectTests extends TypedDatasetSuite { } test("select('a,'b,'c,'d,'a, 'c) FROM abcd") { - def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])( - implicit - ea: TypedEncoder[A], - eb: TypedEncoder[B], - ec: TypedEncoder[C], - ed: TypedEncoder[D], - ex4: TypedEncoder[X4[A, B, C, D]], - ca: ClassTag[A] - ): Prop = { + def prop[A, B, C, D]( + data: Vector[X4[A, B, C, D]] + )(implicit + ea: TypedEncoder[A], + eb: TypedEncoder[B], + ec: TypedEncoder[C], + ed: TypedEncoder[D], + ex4: TypedEncoder[X4[A, B, C, D]], + ca: ClassTag[A] + ): Prop = { val dataset = TypedDataset.create(data) val a1 = dataset.col[A]('a) val a2 = dataset.col[B]('b) val a3 = dataset.col[C]('c) val a4 = dataset.col[D]('d) - val dataset2 = dataset.select(a1, a2, a3, a4, a1, a3).collect().run().toVector + val dataset2 = + dataset.select(a1, a2, a3, a4, a1, a3).collect().run().toVector val data2 = data.map { case X4(a, b, c, d) => (a, b, c, d, a, c) } dataset2 ?= data2 @@ -160,22 +167,24 @@ class SelectTests extends TypedDatasetSuite { } test("select('a,'b,'c,'d,'a,'c,'b) FROM abcd") { - def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])( - implicit - ea: TypedEncoder[A], - eb: TypedEncoder[B], - ec: TypedEncoder[C], - ed: TypedEncoder[D], - ex4: TypedEncoder[X4[A, B, C, D]], - ca: ClassTag[A] - ): Prop = { + def prop[A, B, C, D]( + data: Vector[X4[A, B, C, D]] + )(implicit + ea: TypedEncoder[A], + eb: TypedEncoder[B], + ec: TypedEncoder[C], + ed: TypedEncoder[D], + ex4: TypedEncoder[X4[A, B, C, D]], + ca: ClassTag[A] + ): Prop = { val dataset = TypedDataset.create(data) val a1 = dataset.col[A]('a) val a2 = dataset.col[B]('b) val a3 = dataset.col[C]('c) val a4 = dataset.col[D]('d) - val dataset2 = dataset.select(a1, a2, a3, a4, a1, a3, a2).collect().run().toVector + val dataset2 = + dataset.select(a1, a2, a3, 
a4, a1, a3, a2).collect().run().toVector val data2 = data.map { case X4(a, b, c, d) => (a, b, c, d, a, c, b) } dataset2 ?= data2 @@ -187,22 +196,24 @@ class SelectTests extends TypedDatasetSuite { } test("select('a,'b,'c,'d,'a,'c,'b, 'a) FROM abcd") { - def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])( - implicit - ea: TypedEncoder[A], - eb: TypedEncoder[B], - ec: TypedEncoder[C], - ed: TypedEncoder[D], - ex4: TypedEncoder[X4[A, B, C, D]], - ca: ClassTag[A] - ): Prop = { + def prop[A, B, C, D]( + data: Vector[X4[A, B, C, D]] + )(implicit + ea: TypedEncoder[A], + eb: TypedEncoder[B], + ec: TypedEncoder[C], + ed: TypedEncoder[D], + ex4: TypedEncoder[X4[A, B, C, D]], + ca: ClassTag[A] + ): Prop = { val dataset = TypedDataset.create(data) val a1 = dataset.col[A]('a) val a2 = dataset.col[B]('b) val a3 = dataset.col[C]('c) val a4 = dataset.col[D]('d) - val dataset2 = dataset.select(a1, a2, a3, a4, a1, a3, a2, a1).collect().run().toVector + val dataset2 = + dataset.select(a1, a2, a3, a4, a1, a3, a2, a1).collect().run().toVector val data2 = data.map { case X4(a, b, c, d) => (a, b, c, d, a, c, b, a) } dataset2 ?= data2 @@ -214,23 +225,30 @@ class SelectTests extends TypedDatasetSuite { } test("select('a,'b,'c,'d,'a,'c,'b,'a,'c) FROM abcd") { - def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])( - implicit - ea: TypedEncoder[A], - eb: TypedEncoder[B], - ec: TypedEncoder[C], - ed: TypedEncoder[D], - ex4: TypedEncoder[X4[A, B, C, D]], - ca: ClassTag[A] - ): Prop = { + def prop[A, B, C, D]( + data: Vector[X4[A, B, C, D]] + )(implicit + ea: TypedEncoder[A], + eb: TypedEncoder[B], + ec: TypedEncoder[C], + ed: TypedEncoder[D], + ex4: TypedEncoder[X4[A, B, C, D]], + ca: ClassTag[A] + ): Prop = { val dataset = TypedDataset.create(data) val a1 = dataset.col[A]('a) val a2 = dataset.col[B]('b) val a3 = dataset.col[C]('c) val a4 = dataset.col[D]('d) - val dataset2 = dataset.select(a1, a2, a3, a4, a1, a3, a2, a1, a3).collect().run().toVector - val data2 = data.map { case X4(a, b, c, d) => (a, b, c, d, a, c, b, a, c) } + val dataset2 = dataset + .select(a1, a2, a3, a4, a1, a3, a2, a1, a3) + .collect() + .run() + .toVector + val data2 = data.map { + case X4(a, b, c, d) => (a, b, c, d, a, c, b, a, c) + } dataset2 ?= data2 } @@ -241,23 +259,30 @@ class SelectTests extends TypedDatasetSuite { } test("select('a,'b,'c,'d,'a,'c,'b,'a,'c, 'd) FROM abcd") { - def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])( - implicit - ea: TypedEncoder[A], - eb: TypedEncoder[B], - ec: TypedEncoder[C], - ed: TypedEncoder[D], - ex4: TypedEncoder[X4[A, B, C, D]], - ca: ClassTag[A] - ): Prop = { + def prop[A, B, C, D]( + data: Vector[X4[A, B, C, D]] + )(implicit + ea: TypedEncoder[A], + eb: TypedEncoder[B], + ec: TypedEncoder[C], + ed: TypedEncoder[D], + ex4: TypedEncoder[X4[A, B, C, D]], + ca: ClassTag[A] + ): Prop = { val dataset = TypedDataset.create(data) val a1 = dataset.col[A]('a) val a2 = dataset.col[B]('b) val a3 = dataset.col[C]('c) val a4 = dataset.col[D]('d) - val dataset2 = dataset.select(a1, a2, a3, a4, a1, a3, a2, a1, a3, a4).collect().run().toVector - val data2 = data.map { case X4(a, b, c, d) => (a, b, c, d, a, c, b, a, c, d) } + val dataset2 = dataset + .select(a1, a2, a3, a4, a1, a3, a2, a1, a3, a4) + .collect() + .run() + .toVector + val data2 = data.map { + case X4(a, b, c, d) => (a, b, c, d, a, c, b, a, c, d) + } dataset2 ?= data2 } @@ -268,12 +293,13 @@ class SelectTests extends TypedDatasetSuite { } test("select('a.b)") { - def prop[A, B, C](data: Vector[X2[X2[A, B], C]])( - implicit - eabc: TypedEncoder[X2[X2[A, 
B], C]], - eb: TypedEncoder[B], - cb: ClassTag[B] - ): Prop = { + def prop[A, B, C]( + data: Vector[X2[X2[A, B], C]] + )(implicit + eabc: TypedEncoder[X2[X2[A, B], C]], + eb: TypedEncoder[B], + cb: ClassTag[B] + ): Prop = { val dataset = TypedDataset.create(data) val AB = dataset.colMany('a, 'b) @@ -287,13 +313,15 @@ class SelectTests extends TypedDatasetSuite { } test("select with column expression addition") { - def prop[A](data: Vector[X1[A]], const: A)( - implicit - eabc: TypedEncoder[X1[A]], - anum: CatalystNumeric[A], - num: Numeric[A], - eb: TypedEncoder[A] - ): Prop = { + def prop[A]( + data: Vector[X1[A]], + const: A + )(implicit + eabc: TypedEncoder[X1[A]], + anum: CatalystNumeric[A], + num: Numeric[A], + eb: TypedEncoder[A] + ): Prop = { val ds = TypedDataset.create(data) val dataset2 = ds.select(ds('a) + const).collect().run().toVector @@ -309,13 +337,15 @@ class SelectTests extends TypedDatasetSuite { } test("select with column expression multiplication") { - def prop[A](data: Vector[X1[A]], const: A)( - implicit - eabc: TypedEncoder[X1[A]], - anum: CatalystNumeric[A], - num: Numeric[A], - eb: TypedEncoder[A] - ): Prop = { + def prop[A]( + data: Vector[X1[A]], + const: A + )(implicit + eabc: TypedEncoder[X1[A]], + anum: CatalystNumeric[A], + num: Numeric[A], + eb: TypedEncoder[A] + ): Prop = { val ds = TypedDataset.create(data) val dataset2 = ds.select(ds('a) * const).collect().run().toVector @@ -331,13 +361,15 @@ class SelectTests extends TypedDatasetSuite { } test("select with column expression subtraction") { - def prop[A](data: Vector[X1[A]], const: A)( - implicit - eabc: TypedEncoder[X1[A]], - cnum: CatalystNumeric[A], - num: Numeric[A], - eb: TypedEncoder[A] - ): Prop = { + def prop[A]( + data: Vector[X1[A]], + const: A + )(implicit + eabc: TypedEncoder[X1[A]], + cnum: CatalystNumeric[A], + num: Numeric[A], + eb: TypedEncoder[A] + ): Prop = { val ds = TypedDataset.create(data) val dataset2 = ds.select(ds('a) - const).collect().run().toVector @@ -352,17 +384,24 @@ class SelectTests extends TypedDatasetSuite { } test("select with column expression division") { - def prop[A](data: Vector[X1[A]], const: A)( - implicit - eabc: TypedEncoder[X1[A]], - anum: CatalystNumeric[A], - frac: Fractional[A], - eb: TypedEncoder[A] - ): Prop = { + def prop[A]( + data: Vector[X1[A]], + const: A + )(implicit + eabc: TypedEncoder[X1[A]], + anum: CatalystNumeric[A], + frac: Fractional[A], + eb: TypedEncoder[A] + ): Prop = { val ds = TypedDataset.create(data) if (const != 0) { - val dataset2 = ds.select(ds('a) / const).collect().run().toVector.asInstanceOf[Vector[A]] + val dataset2 = ds + .select(ds('a) / const) + .collect() + .run() + .toVector + .asInstanceOf[Vector[A]] val data2 = data.map { case X1(a) => frac.div(a, const) } dataset2 ?= data2 } else 0 ?= 0 @@ -377,23 +416,36 @@ class SelectTests extends TypedDatasetSuite { val t: TypedDataset[(Int, Int)] = e.select(e.col('i) * 2, e.col('i)) assert(t.select(t.col('_1)).collect().run().toList === List(2)) // Issue #54 - val fooT = t.select(t.col('_1)).deserialized.map(x => Tuple1.apply(x)).as[Foo] + val fooT = + t.select(t.col('_1)).deserialized.map(x => Tuple1.apply(x)).as[Foo] assert(fooT.select(fooT('i)).collect().run().toList === List(2)) } test("unary - on arithmetic") { - val e = TypedDataset.create[(Int, String, Int)]((1, "a", 2) :: (2, "b", 4) :: (2, "b", 1) :: Nil) + val e = TypedDataset.create[(Int, String, Int)]( + (1, "a", 2) :: (2, "b", 4) :: (2, "b", 1) :: Nil + ) assert(e.select(-e('_1)).collect().run().toVector === Vector(-1, 
-2, -2)) - assert(e.select(-(e('_1) + e('_3))).collect().run().toVector === Vector(-3, -6, -3)) + assert( + e.select(-(e('_1) + e('_3))).collect().run().toVector === Vector( + -3, + -6, + -3 + ) + ) } test("unary - on strings should not type check") { - val e = TypedDataset.create[(Int, String, Long)]((1, "a", 2L) :: (2, "b", 4L) :: (2, "b", 1L) :: Nil) + val e = TypedDataset.create[(Int, String, Long)]( + (1, "a", 2L) :: (2, "b", 4L) :: (2, "b", 1L) :: Nil + ) illTyped("""e.select( -e('_2) )""") } test("select with aggregation operations is not supported") { - val e = TypedDataset.create[(Int, String, Long)]((1, "a", 2L) :: (2, "b", 4L) :: (2, "b", 1L) :: Nil) + val e = TypedDataset.create[(Int, String, Long)]( + (1, "a", 2L) :: (2, "b", 4L) :: (2, "b", 1L) :: Nil + ) illTyped("""e.select(frameless.functions.aggregate.sum(e('_1)))""") } } diff --git a/dataset/src/test/scala/frameless/SelfJoinTests.scala b/dataset/src/test/scala/frameless/SelfJoinTests.scala index cede7be2a..847a23e0d 100644 --- a/dataset/src/test/scala/frameless/SelfJoinTests.scala +++ b/dataset/src/test/scala/frameless/SelfJoinTests.scala @@ -2,13 +2,18 @@ package frameless import org.scalacheck.Prop import org.scalacheck.Prop._ -import org.apache.spark.sql.{SparkSession, functions => sparkFunctions} +import org.apache.spark.sql.{ SparkSession, functions => sparkFunctions } class SelfJoinTests extends TypedDatasetSuite { + // Without crossJoin.enabled=true Spark doesn't like trivial join conditions: // [error] Join condition is missing or trivial. // [error] Use the CROSS JOIN syntax to allow cartesian products between these relations. - def allowTrivialJoin[T](body: => T)(implicit session: SparkSession): T = { + def allowTrivialJoin[T]( + body: => T + )(implicit + session: SparkSession + ): T = { val crossJoin = "spark.sql.crossJoin.enabled" val oldSetting = session.conf.get(crossJoin) session.conf.set(crossJoin, "true") @@ -17,7 +22,11 @@ class SelfJoinTests extends TypedDatasetSuite { result } - def allowAmbiguousJoin[T](body: => T)(implicit session: SparkSession): T = { + def allowAmbiguousJoin[T]( + body: => T + )(implicit + session: SparkSession + ): T = { val crossJoin = "spark.sql.analyzer.failAmbiguousSelfJoin" val oldSetting = session.conf.get(crossJoin) session.conf.set(crossJoin, "false") @@ -27,22 +36,26 @@ class SelfJoinTests extends TypedDatasetSuite { } test("self join with colLeft/colRight disambiguation") { - def prop[ - A : TypedEncoder : Ordering, - B : TypedEncoder : Ordering - ](dx: List[X2[A, B]], d: X2[A, B]): Prop = allowAmbiguousJoin { + def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Ordering]( + dx: List[X2[A, B]], + d: X2[A, B] + ): Prop = allowAmbiguousJoin { val data = d :: dx val ds = TypedDataset.create(data) // This is the way to write unambiguous self-join in vanilla, see https://goo.gl/XnkSUD val df1 = ds.dataset.as("df1") val df2 = ds.dataset.as("df2") - val vanilla = df1.join(df2, - sparkFunctions.col("df1.a") === sparkFunctions.col("df2.a")).count() + val vanilla = df1 + .join(df2, sparkFunctions.col("df1.a") === sparkFunctions.col("df2.a")) + .count() - val typed = ds.joinInner(ds)( - ds.colLeft('a) === ds.colRight('a) - ).count().run() + val typed = ds + .joinInner(ds)( + ds.colLeft('a) === ds.colRight('a) + ) + .count() + .run() vanilla ?= typed } @@ -51,47 +64,60 @@ class SelfJoinTests extends TypedDatasetSuite { } test("trivial self join") { - def prop[ - A : TypedEncoder : Ordering, - B : TypedEncoder : Ordering - ](dx: List[X2[A, B]], d: X2[A, B]): Prop = - 
allowTrivialJoin { allowAmbiguousJoin { - - val data = d :: dx - val ds = TypedDataset.create(data) - val untyped = ds.dataset - // Interestingly, even with aliasing it seems that it's impossible to - // obtain a trivial join condition of shape df1.a == df1.a, Spark we - // always interpret that as df1.a == df2.a. For the purpose of this - // test we fall-back to lit(true) instead. - // val trivial = sparkFunctions.col("df1.a") === sparkFunctions.col("df1.a") - val trivial = sparkFunctions.lit(true) - val vanilla = untyped.as("df1").join(untyped.as("df2"), trivial).count() - - val typed = ds.joinInner(ds)(ds.colLeft('a) === ds.colLeft('a)).count().run - vanilla ?= typed - } } + def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Ordering]( + dx: List[X2[A, B]], + d: X2[A, B] + ): Prop = + allowTrivialJoin { + allowAmbiguousJoin { + + val data = d :: dx + val ds = TypedDataset.create(data) + val untyped = ds.dataset + // Interestingly, even with aliasing it seems that it's impossible to + // obtain a trivial join condition of shape df1.a == df1.a, Spark we + // always interpret that as df1.a == df2.a. For the purpose of this + // test we fall-back to lit(true) instead. + // val trivial = sparkFunctions.col("df1.a") === sparkFunctions.col("df1.a") + val trivial = sparkFunctions.lit(true) + val vanilla = + untyped.as("df1").join(untyped.as("df2"), trivial).count() + + val typed = + ds.joinInner(ds)(ds.colLeft('a) === ds.colLeft('a)).count().run + vanilla ?= typed + } + } check(prop[Int, Int] _) } test("self join with unambiguous expression") { def prop[ - A : TypedEncoder : CatalystNumeric : Ordering, - B : TypedEncoder : Ordering - ](data: List[X3[A, A, B]]): Prop = allowAmbiguousJoin { + A: TypedEncoder: CatalystNumeric: Ordering, + B: TypedEncoder: Ordering + ](data: List[X3[A, A, B]] + ): Prop = allowAmbiguousJoin { val ds = TypedDataset.create(data) val df1 = ds.dataset.alias("df1") val df2 = ds.dataset.alias("df2") - val vanilla = df1.join(df2, - (sparkFunctions.col("df1.a") + sparkFunctions.col("df1.b")) === - (sparkFunctions.col("df2.a") + sparkFunctions.col("df2.b"))).count() - - val typed = ds.joinInner(ds)( - (ds.colLeft('a) + ds.colLeft('b)) === (ds.colRight('a) + ds.colRight('b)) - ).count().run() + val vanilla = df1 + .join( + df2, + (sparkFunctions.col("df1.a") + sparkFunctions.col("df1.b")) === + (sparkFunctions.col("df2.a") + sparkFunctions.col("df2.b")) + ) + .count() + + val typed = ds + .joinInner(ds)( + (ds.colLeft('a) + ds.colLeft('b)) === (ds.colRight('a) + ds + .colRight('b)) + ) + .count() + .run() vanilla ?= typed } @@ -99,41 +125,57 @@ class SelfJoinTests extends TypedDatasetSuite { check(prop[Int, Int] _) } - test("Do you want ambiguous self join? This is how you get ambiguous self join.") { + test( + "Do you want ambiguous self join? This is how you get ambiguous self join." + ) { def prop[ - A : TypedEncoder : CatalystNumeric : Ordering, - B : TypedEncoder : Ordering - ](data: List[X3[A, A, B]]): Prop = - allowTrivialJoin { allowAmbiguousJoin { - val ds = TypedDataset.create(data) - - // The point I'm making here is that it "behaves just like Spark". I - // don't know (or really care about how) how Spark disambiguates that - // internally... 
- val vanilla = ds.dataset.join(ds.dataset, - (ds.dataset("a") + ds.dataset("b")) === - (ds.dataset("a") + ds.dataset("b"))).count() - - val typed = ds.joinInner(ds)( - (ds.col('a) + ds.col('b)) === (ds.col('a) + ds.col('b)) - ).count().run() - - vanilla ?= typed - } } - - check(prop[Int, Int] _) - } + A: TypedEncoder: CatalystNumeric: Ordering, + B: TypedEncoder: Ordering + ](data: List[X3[A, A, B]] + ): Prop = + allowTrivialJoin { + allowAmbiguousJoin { + val ds = TypedDataset.create(data) + + // The point I'm making here is that it "behaves just like Spark". I + // don't know (or really care about how) how Spark disambiguates that + // internally... + val vanilla = ds.dataset + .join( + ds.dataset, + (ds.dataset("a") + ds.dataset("b")) === + (ds.dataset("a") + ds.dataset("b")) + ) + .count() + + val typed = ds + .joinInner(ds)( + (ds.col('a) + ds.col('b)) === (ds.col('a) + ds.col('b)) + ) + .count() + .run() + + vanilla ?= typed + } + } + + check(prop[Int, Int] _) + } test("colLeft and colRight are equivalent to col outside of joins") { - def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])( - implicit - ea: TypedEncoder[A], - ex4: TypedEncoder[X4[A, B, C, D]] - ): Prop = { + def prop[A, B, C, D]( + data: Vector[X4[A, B, C, D]] + )(implicit + ea: TypedEncoder[A], + ex4: TypedEncoder[X4[A, B, C, D]] + ): Prop = { val dataset = TypedDataset.create(data) - val selectedCol = dataset.select(dataset.col [A]('a)).collect().run().toVector - val selectedColLeft = dataset.select(dataset.colLeft [A]('a)).collect().run().toVector - val selectedColRight = dataset.select(dataset.colRight[A]('a)).collect().run().toVector + val selectedCol = + dataset.select(dataset.col[A]('a)).collect().run().toVector + val selectedColLeft = + dataset.select(dataset.colLeft[A]('a)).collect().run().toVector + val selectedColRight = + dataset.select(dataset.colRight[A]('a)).collect().run().toVector (selectedCol ?= selectedColLeft) && (selectedCol ?= selectedColRight) } @@ -145,16 +187,26 @@ class SelfJoinTests extends TypedDatasetSuite { } test("colLeft and colRight are equivalent to col outside of joins - via files (codegen)") { - def prop[A, B, C, D](data: Vector[X4[A, B, C, D]])( - implicit - ea: TypedEncoder[A], - ex4: TypedEncoder[X4[A, B, C, D]] - ): Prop = { - TypedDataset.create(data).write.mode("overwrite").parquet("./target/testData") - val dataset = TypedDataset.createUnsafe[X4[A, B, C, D]](session.read.parquet("./target/testData")) - val selectedCol = dataset.select(dataset.col [A]('a)).collect().run().toVector - val selectedColLeft = dataset.select(dataset.colLeft [A]('a)).collect().run().toVector - val selectedColRight = dataset.select(dataset.colRight[A]('a)).collect().run().toVector + def prop[A, B, C, D]( + data: Vector[X4[A, B, C, D]] + )(implicit + ea: TypedEncoder[A], + ex4: TypedEncoder[X4[A, B, C, D]] + ): Prop = { + TypedDataset + .create(data) + .write + .mode("overwrite") + .parquet("./target/testData") + val dataset = TypedDataset.createUnsafe[X4[A, B, C, D]]( + session.read.parquet("./target/testData") + ) + val selectedCol = + dataset.select(dataset.col[A]('a)).collect().run().toVector + val selectedColLeft = + dataset.select(dataset.colLeft[A]('a)).collect().run().toVector + val selectedColRight = + dataset.select(dataset.colRight[A]('a)).collect().run().toVector (selectedCol ?= selectedColLeft) && (selectedCol ?= selectedColRight) } diff --git a/dataset/src/test/scala/frameless/TypedDatasetSuite.scala b/dataset/src/test/scala/frameless/TypedDatasetSuite.scala index 8a4697835..0845b39f8 100644 
--- a/dataset/src/test/scala/frameless/TypedDatasetSuite.scala +++ b/dataset/src/test/scala/frameless/TypedDatasetSuite.scala @@ -2,28 +2,35 @@ package frameless import com.globalmentor.apache.hadoop.fs.BareLocalFileSystem import org.apache.hadoop.fs.local.StreamingFS -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{SQLContext, SparkSession} +import org.apache.spark.{ SparkConf, SparkContext } +import org.apache.spark.sql.{ SQLContext, SparkSession } import org.scalactic.anyvals.PosZInt import org.scalatest.BeforeAndAfterAll import org.scalatestplus.scalacheck.Checkers import org.scalacheck.Prop import org.scalacheck.Prop._ -import scala.util.{Properties, Try} +import scala.util.{ Properties, Try } import org.scalatest.funsuite.AnyFunSuite trait SparkTesting { self: BeforeAndAfterAll => - val appID: String = new java.util.Date().toString + math.floor(math.random * 10E4).toLong.toString + val appID: String = new java.util.Date().toString + math + .floor(math.random * 10e4) + .toLong + .toString /** * Allows bare naked to be used instead of winutils for testing / dev */ def registerFS(sparkConf: SparkConf): SparkConf = { if (System.getProperty("os.name").startsWith("Windows")) - sparkConf.set("spark.hadoop.fs.file.impl", classOf[BareLocalFileSystem].getName). - set("spark.hadoop.fs.AbstractFileSystem.file.impl", classOf[StreamingFS].getName) + sparkConf + .set("spark.hadoop.fs.file.impl", classOf[BareLocalFileSystem].getName) + .set( + "spark.hadoop.fs.AbstractFileSystem.file.impl", + classOf[StreamingFS].getName + ) else sparkConf } @@ -40,9 +47,9 @@ trait SparkTesting { self: BeforeAndAfterAll => implicit def sc: SparkContext = session.sparkContext implicit def sqlContext: SQLContext = session.sqlContext - def registerOptimizations(sqlContext: SQLContext): Unit = { } + def registerOptimizations(sqlContext: SQLContext): Unit = {} - def addSparkConfigProperties(config: SparkConf): Unit = { } + def addSparkConfigProperties(config: SparkConf): Unit = {} override def beforeAll(): Unit = { assert(s == null) @@ -59,11 +66,16 @@ trait SparkTesting { self: BeforeAndAfterAll => } } +class TypedDatasetSuite + extends AnyFunSuite + with Checkers + with BeforeAndAfterAll + with SparkTesting { -class TypedDatasetSuite extends AnyFunSuite with Checkers with BeforeAndAfterAll with SparkTesting { // Limit size of generated collections and number of checks to avoid OutOfMemoryError implicit override val generatorDrivenConfig: PropertyCheckConfiguration = { - def getPosZInt(name: String, default: PosZInt) = Properties.envOrNone(s"FRAMELESS_GEN_${name}") + def getPosZInt(name: String, default: PosZInt) = Properties + .envOrNone(s"FRAMELESS_GEN_${name}") .flatMap(s => Try(s.toInt).toOption) .flatMap(PosZInt.from) .getOrElse(default) @@ -75,17 +87,24 @@ class TypedDatasetSuite extends AnyFunSuite with Checkers with BeforeAndAfterAll implicit val sparkDelay: SparkDelay[Job] = Job.framelessSparkDelayForJob - def approximatelyEqual[A](a: A, b: A)(implicit numeric: Numeric[A]): Prop = { + def approximatelyEqual[A]( + a: A, + b: A + )(implicit + numeric: Numeric[A] + ): Prop = { val da = numeric.toDouble(a) val db = numeric.toDouble(b) - val epsilon = 1E-6 + val epsilon = 1e-6 // Spark has a weird behaviour concerning expressions that should return Inf // Most of the time they return NaN instead, for instance stddev of Seq(-7.827553978923477E227, -5.009124275715786E153) - if((da.isNaN || da.isInfinity) && (db.isNaN || db.isInfinity)) proved + if ((da.isNaN || da.isInfinity) && 
(db.isNaN || db.isInfinity)) proved else if ( (da - db).abs < epsilon || - (da - db).abs < da.abs / 100) - proved - else falsified :| s"Expected $a but got $b, which is more than 1% off and greater than epsilon = $epsilon." + (da - db).abs < da.abs / 100 + ) + proved + else + falsified :| s"Expected $a but got $b, which is more than 1% off and greater than epsilon = $epsilon." } } diff --git a/dataset/src/test/scala/frameless/UdtEncodedClass.scala b/dataset/src/test/scala/frameless/UdtEncodedClass.scala index 4e5c2c6d9..917e4c1f1 100644 --- a/dataset/src/test/scala/frameless/UdtEncodedClass.scala +++ b/dataset/src/test/scala/frameless/UdtEncodedClass.scala @@ -1,14 +1,19 @@ package frameless import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.expressions.{ + GenericInternalRow, + UnsafeArrayData +} import org.apache.spark.sql.types._ import org.apache.spark.sql.FramelessInternals.UserDefinedType @SQLUserDefinedType(udt = classOf[UdtEncodedClassUdt]) class UdtEncodedClass(val a: Int, val b: Array[Double]) { + override def equals(other: Any): Boolean = other match { - case that: UdtEncodedClass => a == that.a && java.util.Arrays.equals(b, that.b) + case that: UdtEncodedClass => + a == that.a && java.util.Arrays.equals(b, that.b) case _ => false } @@ -25,11 +30,18 @@ object UdtEncodedClass { } class UdtEncodedClassUdt extends UserDefinedType[UdtEncodedClass] { + def sqlType: DataType = { - StructType(Seq( - StructField("a", IntegerType, nullable = false), - StructField("b", ArrayType(DoubleType, containsNull = false), nullable = false) - )) + StructType( + Seq( + StructField("a", IntegerType, nullable = false), + StructField( + "b", + ArrayType(DoubleType, containsNull = false), + nullable = false + ) + ) + ) } def serialize(obj: UdtEncodedClass): InternalRow = { @@ -40,7 +52,8 @@ class UdtEncodedClassUdt extends UserDefinedType[UdtEncodedClass] { } def deserialize(datum: Any): UdtEncodedClass = datum match { - case row: InternalRow => new UdtEncodedClass(row.getInt(0), row.getArray(1).toDoubleArray()) + case row: InternalRow => + new UdtEncodedClass(row.getInt(0), row.getArray(1).toDoubleArray()) } def userClass: Class[UdtEncodedClass] = classOf[UdtEncodedClass] diff --git a/dataset/src/test/scala/frameless/WithColumnTest.scala b/dataset/src/test/scala/frameless/WithColumnTest.scala index c41c4e726..4d62a4f05 100644 --- a/dataset/src/test/scala/frameless/WithColumnTest.scala +++ b/dataset/src/test/scala/frameless/WithColumnTest.scala @@ -8,28 +8,32 @@ class WithColumnTest extends TypedDatasetSuite { import WithColumnTest._ test("fail to compile on missing value") { - val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil) + val f: TypedDataset[X] = + TypedDataset.create(X(1, 1) :: X(1, 1) :: X(1, 10) :: Nil) illTyped { """val fNew: TypedDataset[XMissing] = f.withColumn[XMissing](f('j) === 10)""" } } test("fail to compile on different column name") { - val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil) + val f: TypedDataset[X] = + TypedDataset.create(X(1, 1) :: X(1, 1) :: X(1, 10) :: Nil) illTyped { """val fNew: TypedDataset[XDifferentColumnName] = f.withColumn[XDifferentColumnName](f('j) === 10)""" } } test("fail to compile on added column name") { - val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil) + val f: TypedDataset[X] = + TypedDataset.create(X(1, 1) :: X(1, 1) :: X(1, 10) :: Nil) 
illTyped { """val fNew: TypedDataset[XAdded] = f.withColumn[XAdded](f('j) === 10)""" } } test("fail to compile on wrong typed column") { - val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil) + val f: TypedDataset[X] = + TypedDataset.create(X(1, 1) :: X(1, 1) :: X(1, 10) :: Nil) illTyped { """val fNew: TypedDataset[XWrongType] = f.withColumn[XWrongType](f('j) === 10)""" } @@ -54,13 +58,10 @@ class WithColumnTest extends TypedDatasetSuite { } test("update in place") { - def prop[A : TypedEncoder](startValue: A, replaceValue: A): Prop = { + def prop[A: TypedEncoder](startValue: A, replaceValue: A): Prop = { val d = TypedDataset.create(X2(startValue, replaceValue) :: Nil) - val X2(a, b) = d.withColumnReplaced('a, d('b)) - .collect() - .run() - .head + val X2(a, b) = d.withColumnReplaced('a, d('b)).collect().run().head a ?= b } diff --git a/dataset/src/test/scala/frameless/XN.scala b/dataset/src/test/scala/frameless/XN.scala index c23d4b45d..a92f7d9d1 100644 --- a/dataset/src/test/scala/frameless/XN.scala +++ b/dataset/src/test/scala/frameless/XN.scala @@ -1,14 +1,18 @@ package frameless -import org.scalacheck.{Arbitrary, Cogen} +import org.scalacheck.{ Arbitrary, Cogen } case class X1[A](a: A) object X1 { + implicit def arbitrary[A: Arbitrary]: Arbitrary[X1[A]] = Arbitrary(implicitly[Arbitrary[A]].arbitrary.map(X1(_))) - implicit def cogen[A](implicit A: Cogen[A]): Cogen[X1[A]] = + implicit def cogen[A]( + implicit + A: Cogen[A] + ): Cogen[X1[A]] = A.contramap(_.a) implicit def ordering[A: Ordering]: Ordering[X1[A]] = Ordering[A].on(_.a) @@ -17,89 +21,225 @@ object X1 { case class X2[A, B](a: A, b: B) object X2 { - implicit def arbitrary[A: Arbitrary, B: Arbitrary]: Arbitrary[X2[A, B]] = - Arbitrary(Arbitrary.arbTuple2[A, B].arbitrary.map((X2.apply[A, B] _).tupled)) - implicit def cogen[A, B](implicit A: Cogen[A], B: Cogen[B]): Cogen[X2[A, B]] = + implicit def arbitrary[A: Arbitrary, B: Arbitrary]: Arbitrary[X2[A, B]] = + Arbitrary( + Arbitrary.arbTuple2[A, B].arbitrary.map((X2.apply[A, B] _).tupled) + ) + + implicit def cogen[A, B]( + implicit + A: Cogen[A], + B: Cogen[B] + ): Cogen[X2[A, B]] = Cogen.tuple2(A, B).contramap(x => (x.a, x.b)) - implicit def ordering[A: Ordering, B: Ordering]: Ordering[X2[A, B]] = Ordering.Tuple2[A, B].on(x => (x.a, x.b)) + implicit def ordering[A: Ordering, B: Ordering]: Ordering[X2[A, B]] = + Ordering.Tuple2[A, B].on(x => (x.a, x.b)) } case class X3[A, B, C](a: A, b: B, c: C) object X3 { - implicit def arbitrary[A: Arbitrary, B: Arbitrary, C: Arbitrary]: Arbitrary[X3[A, B, C]] = - Arbitrary(Arbitrary.arbTuple3[A, B, C].arbitrary.map((X3.apply[A, B, C] _).tupled)) - implicit def cogen[A, B, C](implicit A: Cogen[A], B: Cogen[B], C: Cogen[C]): Cogen[X3[A, B, C]] = + implicit def arbitrary[ + A: Arbitrary, + B: Arbitrary, + C: Arbitrary + ]: Arbitrary[X3[A, B, C]] = + Arbitrary( + Arbitrary.arbTuple3[A, B, C].arbitrary.map((X3.apply[A, B, C] _).tupled) + ) + + implicit def cogen[A, B, C]( + implicit + A: Cogen[A], + B: Cogen[B], + C: Cogen[C] + ): Cogen[X3[A, B, C]] = Cogen.tuple3(A, B, C).contramap(x => (x.a, x.b, x.c)) - implicit def ordering[A: Ordering, B: Ordering, C: Ordering]: Ordering[X3[A, B, C]] = + implicit def ordering[ + A: Ordering, + B: Ordering, + C: Ordering + ]: Ordering[X3[A, B, C]] = Ordering.Tuple3[A, B, C].on(x => (x.a, x.b, x.c)) } case class X3U[A, B, C](a: A, b: B, u: Unit, c: C) object X3U { - implicit def arbitrary[A: Arbitrary, B: Arbitrary, C: Arbitrary]: Arbitrary[X3U[A, B, C]] = - 
Arbitrary(Arbitrary.arbTuple3[A, B, C].arbitrary.map(x => X3U[A, B, C](x._1, x._2, (), x._3))) - implicit def cogen[A, B, C](implicit A: Cogen[A], B: Cogen[B], C: Cogen[C]): Cogen[X3U[A, B, C]] = + implicit def arbitrary[ + A: Arbitrary, + B: Arbitrary, + C: Arbitrary + ]: Arbitrary[X3U[A, B, C]] = + Arbitrary( + Arbitrary + .arbTuple3[A, B, C] + .arbitrary + .map(x => X3U[A, B, C](x._1, x._2, (), x._3)) + ) + + implicit def cogen[A, B, C]( + implicit + A: Cogen[A], + B: Cogen[B], + C: Cogen[C] + ): Cogen[X3U[A, B, C]] = Cogen.tuple3(A, B, C).contramap(x => (x.a, x.b, x.c)) - implicit def ordering[A: Ordering, B: Ordering, C: Ordering]: Ordering[X3U[A, B, C]] = + implicit def ordering[ + A: Ordering, + B: Ordering, + C: Ordering + ]: Ordering[X3U[A, B, C]] = Ordering.Tuple3[A, B, C].on(x => (x.a, x.b, x.c)) } case class X3KV[A, B, C](key: A, value: B, c: C) object X3KV { - implicit def arbitrary[A: Arbitrary, B: Arbitrary, C: Arbitrary]: Arbitrary[X3KV[A, B, C]] = - Arbitrary(Arbitrary.arbTuple3[A, B, C].arbitrary.map((X3KV.apply[A, B, C] _).tupled)) - implicit def cogen[A, B, C](implicit A: Cogen[A], B: Cogen[B], C: Cogen[C]): Cogen[X3KV[A, B, C]] = + implicit def arbitrary[ + A: Arbitrary, + B: Arbitrary, + C: Arbitrary + ]: Arbitrary[X3KV[A, B, C]] = + Arbitrary( + Arbitrary.arbTuple3[A, B, C].arbitrary.map((X3KV.apply[A, B, C] _).tupled) + ) + + implicit def cogen[A, B, C]( + implicit + A: Cogen[A], + B: Cogen[B], + C: Cogen[C] + ): Cogen[X3KV[A, B, C]] = Cogen.tuple3(A, B, C).contramap(x => (x.key, x.value, x.c)) - implicit def ordering[A: Ordering, B: Ordering, C: Ordering]: Ordering[X3KV[A, B, C]] = + implicit def ordering[ + A: Ordering, + B: Ordering, + C: Ordering + ]: Ordering[X3KV[A, B, C]] = Ordering.Tuple3[A, B, C].on(x => (x.key, x.value, x.c)) } case class X4[A, B, C, D](a: A, b: B, c: C, d: D) object X4 { - implicit def arbitrary[A: Arbitrary, B: Arbitrary, C: Arbitrary, D: Arbitrary]: Arbitrary[X4[A, B, C, D]] = - Arbitrary(Arbitrary.arbTuple4[A, B, C, D].arbitrary.map((X4.apply[A, B, C, D] _).tupled)) - implicit def cogen[A, B, C, D](implicit A: Cogen[A], B: Cogen[B], C: Cogen[C], D: Cogen[D]): Cogen[X4[A, B, C, D]] = + implicit def arbitrary[ + A: Arbitrary, + B: Arbitrary, + C: Arbitrary, + D: Arbitrary + ]: Arbitrary[X4[A, B, C, D]] = + Arbitrary( + Arbitrary + .arbTuple4[A, B, C, D] + .arbitrary + .map((X4.apply[A, B, C, D] _).tupled) + ) + + implicit def cogen[A, B, C, D]( + implicit + A: Cogen[A], + B: Cogen[B], + C: Cogen[C], + D: Cogen[D] + ): Cogen[X4[A, B, C, D]] = Cogen.tuple4(A, B, C, D).contramap(x => (x.a, x.b, x.c, x.d)) - implicit def ordering[A: Ordering, B: Ordering, C: Ordering, D: Ordering]: Ordering[X4[A, B, C, D]] = + implicit def ordering[ + A: Ordering, + B: Ordering, + C: Ordering, + D: Ordering + ]: Ordering[X4[A, B, C, D]] = Ordering.Tuple4[A, B, C, D].on(x => (x.a, x.b, x.c, x.d)) } case class X5[A, B, C, D, E](a: A, b: B, c: C, d: D, e: E) object X5 { - implicit def arbitrary[A: Arbitrary, B: Arbitrary, C: Arbitrary, D: Arbitrary, E: Arbitrary]: Arbitrary[X5[A, B, C, D, E]] = - Arbitrary(Arbitrary.arbTuple5[A, B, C, D, E].arbitrary.map((X5.apply[A, B, C, D, E] _).tupled)) - implicit def cogen[A, B, C, D, E](implicit A: Cogen[A], B: Cogen[B], C: Cogen[C], D: Cogen[D], E: Cogen[E]): Cogen[X5[A, B, C, D, E]] = + implicit def arbitrary[ + A: Arbitrary, + B: Arbitrary, + C: Arbitrary, + D: Arbitrary, + E: Arbitrary + ]: Arbitrary[X5[A, B, C, D, E]] = + Arbitrary( + Arbitrary + .arbTuple5[A, B, C, D, E] + .arbitrary + .map((X5.apply[A, B, C, D, E] 
_).tupled) + ) + + implicit def cogen[A, B, C, D, E]( + implicit + A: Cogen[A], + B: Cogen[B], + C: Cogen[C], + D: Cogen[D], + E: Cogen[E] + ): Cogen[X5[A, B, C, D, E]] = Cogen.tuple5(A, B, C, D, E).contramap(x => (x.a, x.b, x.c, x.d, x.e)) - implicit def ordering[A: Ordering, B: Ordering, C: Ordering, D: Ordering, E: Ordering]: Ordering[X5[A, B, C, D, E]] = + implicit def ordering[ + A: Ordering, + B: Ordering, + C: Ordering, + D: Ordering, + E: Ordering + ]: Ordering[X5[A, B, C, D, E]] = Ordering.Tuple5[A, B, C, D, E].on(x => (x.a, x.b, x.c, x.d, x.e)) } case class X6[A, B, C, D, E, F](a: A, b: B, c: C, d: D, e: E, f: F) object X6 { - implicit def arbitrary[A: Arbitrary, B: Arbitrary, C: Arbitrary, D: Arbitrary, E: Arbitrary, F: Arbitrary]: Arbitrary[X6[A, B, C, D, E, F]] = - Arbitrary(Arbitrary.arbTuple6[A, B, C, D, E, F].arbitrary.map((X6.apply[A, B, C, D, E, F] _).tupled)) - implicit def cogen[A, B, C, D, E, F](implicit A: Cogen[A], B: Cogen[B], C: Cogen[C], D: Cogen[D], E: Cogen[E], F: Cogen[F]): Cogen[X6[A, B, C, D, E, F]] = - Cogen.tuple6(A, B, C, D, E, F).contramap(x => (x.a, x.b, x.c, x.d, x.e, x.f)) - - implicit def ordering[A: Ordering, B: Ordering, C: Ordering, D: Ordering, E: Ordering, F: Ordering]: Ordering[X6[A, B, C, D, E, F]] = + implicit def arbitrary[ + A: Arbitrary, + B: Arbitrary, + C: Arbitrary, + D: Arbitrary, + E: Arbitrary, + F: Arbitrary + ]: Arbitrary[X6[A, B, C, D, E, F]] = + Arbitrary( + Arbitrary + .arbTuple6[A, B, C, D, E, F] + .arbitrary + .map((X6.apply[A, B, C, D, E, F] _).tupled) + ) + + implicit def cogen[A, B, C, D, E, F]( + implicit + A: Cogen[A], + B: Cogen[B], + C: Cogen[C], + D: Cogen[D], + E: Cogen[E], + F: Cogen[F] + ): Cogen[X6[A, B, C, D, E, F]] = + Cogen + .tuple6(A, B, C, D, E, F) + .contramap(x => (x.a, x.b, x.c, x.d, x.e, x.f)) + + implicit def ordering[ + A: Ordering, + B: Ordering, + C: Ordering, + D: Ordering, + E: Ordering, + F: Ordering + ]: Ordering[X6[A, B, C, D, E, F]] = Ordering.Tuple6[A, B, C, D, E, F].on(x => (x.a, x.b, x.c, x.d, x.e, x.f)) -} \ No newline at end of file +} diff --git a/dataset/src/test/scala/frameless/forward/CheckpointTests.scala b/dataset/src/test/scala/frameless/forward/CheckpointTests.scala index 9a1ff8b44..d982f0344 100644 --- a/dataset/src/test/scala/frameless/forward/CheckpointTests.scala +++ b/dataset/src/test/scala/frameless/forward/CheckpointTests.scala @@ -1,8 +1,7 @@ package frameless import org.scalacheck.Prop -import org.scalacheck.Prop.{forAll, _} - +import org.scalacheck.Prop.{ forAll, _ } class CheckpointTests extends TypedDatasetSuite { test("checkpoint") { @@ -18,4 +17,4 @@ class CheckpointTests extends TypedDatasetSuite { check(forAll(prop[Int] _)) check(forAll(prop[String] _)) } -} \ No newline at end of file +} diff --git a/dataset/src/test/scala/frameless/forward/ColumnsTests.scala b/dataset/src/test/scala/frameless/forward/ColumnsTests.scala index 282a72c9a..27e20511d 100644 --- a/dataset/src/test/scala/frameless/forward/ColumnsTests.scala +++ b/dataset/src/test/scala/frameless/forward/ColumnsTests.scala @@ -5,7 +5,14 @@ import org.scalacheck.Prop.forAll class ColumnsTests extends TypedDatasetSuite { test("columns") { - def prop(i: Int, s: String, b: Boolean, l: Long, d: Double, by: Byte): Prop = { + def prop( + i: Int, + s: String, + b: Boolean, + l: Long, + d: Double, + by: Byte + ): Prop = { val x1 = X1(i) :: Nil val x2 = X2(i, s) :: Nil val x3 = X3(i, s, b) :: Nil @@ -13,18 +20,21 @@ class ColumnsTests extends TypedDatasetSuite { val x5 = X5(i, s, b, l, d) :: Nil val x6 = X6(i, s, b, 
l, d, by) :: Nil - val datasets = Seq(TypedDataset.create(x1), TypedDataset.create(x2), - TypedDataset.create(x3), TypedDataset.create(x4), - TypedDataset.create(x5), TypedDataset.create(x6)) + val datasets = Seq( + TypedDataset.create(x1), + TypedDataset.create(x2), + TypedDataset.create(x3), + TypedDataset.create(x4), + TypedDataset.create(x5), + TypedDataset.create(x6) + ) Prop.all(datasets.flatMap { dataset => val columns = dataset.dataset.columns - dataset.columns.map(col => - Prop.propBoolean(columns contains col) - ) + dataset.columns.map(col => Prop.propBoolean(columns contains col)) }: _*) } check(forAll(prop _)) } -} \ No newline at end of file +} diff --git a/dataset/src/test/scala/frameless/forward/DistinctTests.scala b/dataset/src/test/scala/frameless/forward/DistinctTests.scala index 44da5e59e..bd2b4be41 100644 --- a/dataset/src/test/scala/frameless/forward/DistinctTests.scala +++ b/dataset/src/test/scala/frameless/forward/DistinctTests.scala @@ -7,8 +7,14 @@ import math.Ordering class DistinctTests extends TypedDatasetSuite { test("distinct") { // Comparison done with `.sorted` because order is not preserved by Spark for this operation. - def prop[A: TypedEncoder : Ordering](data: Vector[A]): Prop = - TypedDataset.create(data).distinct.collect().run().toVector.sorted ?= data.distinct.sorted + def prop[A: TypedEncoder: Ordering](data: Vector[A]): Prop = + TypedDataset + .create(data) + .distinct + .collect() + .run() + .toVector + .sorted ?= data.distinct.sorted check(forAll(prop[Int] _)) check(forAll(prop[String] _)) diff --git a/dataset/src/test/scala/frameless/forward/HeadTests.scala b/dataset/src/test/scala/frameless/forward/HeadTests.scala index 63f76e003..c3314fc40 100644 --- a/dataset/src/test/scala/frameless/forward/HeadTests.scala +++ b/dataset/src/test/scala/frameless/forward/HeadTests.scala @@ -1,6 +1,12 @@ package frameless.forward -import frameless.{TypedDataset, TypedDatasetSuite, TypedEncoder, TypedExpressionEncoder, X1} +import frameless.{ + TypedDataset, + TypedDatasetSuite, + TypedEncoder, + TypedExpressionEncoder, + X1 +} import org.apache.spark.sql.SparkSession import org.scalacheck.Prop import org.scalacheck.Prop._ @@ -9,17 +15,25 @@ import scala.reflect.ClassTag import org.scalatest.matchers.should.Matchers class HeadTests extends TypedDatasetSuite with Matchers { - def propArray[A: TypedEncoder : ClassTag : Ordering](data: Vector[X1[A]])(implicit c: SparkSession): Prop = { + + def propArray[A: TypedEncoder: ClassTag: Ordering]( + data: Vector[X1[A]] + )(implicit + c: SparkSession + ): Prop = { import c.implicits._ - if(data.nonEmpty) { - val tds = TypedDataset. - create(c.createDataset(data)( + if (data.nonEmpty) { + val tds = TypedDataset.create( + c.createDataset(data)( TypedExpressionEncoder.apply[X1[A]] - ).orderBy($"a".desc)) - (tds.headOption().run().get ?= data.max). - &&(tds.head(1).run().head ?= data.max). 
- &&(tds.head(4).run().toVector ?= - data.sortBy(_.a)(implicitly[Ordering[A]].reverse).take(4)) + ).orderBy($"a".desc) + ) + (tds.headOption().run().get ?= data.max) + .&&(tds.head(1).run().head ?= data.max) + .&&( + tds.head(4).run().toVector ?= + data.sortBy(_.a)(implicitly[Ordering[A]].reverse).take(4) + ) } else Prop.passed } diff --git a/dataset/src/test/scala/frameless/forward/InputFilesTests.scala b/dataset/src/test/scala/frameless/forward/InputFilesTests.scala index 246867e63..461457879 100644 --- a/dataset/src/test/scala/frameless/forward/InputFilesTests.scala +++ b/dataset/src/test/scala/frameless/forward/InputFilesTests.scala @@ -14,7 +14,9 @@ class InputFilesTests extends TypedDatasetSuite with Matchers { val filePath = s"$TEST_OUTPUT_DIR/${UUID.randomUUID()}.txt" TypedDataset.create(data).dataset.write.text(filePath) - val dataset = TypedDataset.create(implicitly[SparkSession].sparkContext.textFile(filePath)) + val dataset = TypedDataset.create( + implicitly[SparkSession].sparkContext.textFile(filePath) + ) dataset.inputFiles sameElements dataset.dataset.inputFiles } @@ -25,7 +27,10 @@ class InputFilesTests extends TypedDatasetSuite with Matchers { inputDataset.dataset.write.csv(filePath) val dataset = TypedDataset.createUnsafe( - implicitly[SparkSession].sqlContext.read.schema(inputDataset.schema).csv(filePath)) + implicitly[SparkSession].sqlContext.read + .schema(inputDataset.schema) + .csv(filePath) + ) dataset.inputFiles sameElements dataset.dataset.inputFiles } @@ -36,7 +41,10 @@ class InputFilesTests extends TypedDatasetSuite with Matchers { inputDataset.dataset.write.json(filePath) val dataset = TypedDataset.createUnsafe( - implicitly[SparkSession].sqlContext.read.schema(inputDataset.schema).json(filePath)) + implicitly[SparkSession].sqlContext.read + .schema(inputDataset.schema) + .json(filePath) + ) dataset.inputFiles sameElements dataset.dataset.inputFiles } @@ -45,4 +53,4 @@ class InputFilesTests extends TypedDatasetSuite with Matchers { check(forAll(propCsv[String] _)) check(forAll(propJson[String] _)) } -} \ No newline at end of file +} diff --git a/dataset/src/test/scala/frameless/forward/IntersectTests.scala b/dataset/src/test/scala/frameless/forward/IntersectTests.scala index f0edb856e..396356857 100644 --- a/dataset/src/test/scala/frameless/forward/IntersectTests.scala +++ b/dataset/src/test/scala/frameless/forward/IntersectTests.scala @@ -6,10 +6,14 @@ import math.Ordering class IntersectTests extends TypedDatasetSuite { test("intersect") { - def prop[A: TypedEncoder : Ordering](data1: Vector[A], data2: Vector[A]): Prop = { + def prop[A: TypedEncoder: Ordering]( + data1: Vector[A], + data2: Vector[A] + ): Prop = { val dataset1 = TypedDataset.create(data1) val dataset2 = TypedDataset.create(data2) - val datasetIntersect = dataset1.intersect(dataset2).collect().run().toVector + val datasetIntersect = + dataset1.intersect(dataset2).collect().run().toVector // Vector `intersect` is the multiset intersection, while Spark throws away duplicates. 
val dataIntersect = data1.intersect(data2).distinct diff --git a/dataset/src/test/scala/frameless/forward/IsLocalTests.scala b/dataset/src/test/scala/frameless/forward/IsLocalTests.scala index f61d25cd1..71fbd27ce 100644 --- a/dataset/src/test/scala/frameless/forward/IsLocalTests.scala +++ b/dataset/src/test/scala/frameless/forward/IsLocalTests.scala @@ -14,4 +14,4 @@ class IsLocalTests extends TypedDatasetSuite { check(forAll(prop[Int] _)) check(forAll(prop[String] _)) } -} \ No newline at end of file +} diff --git a/dataset/src/test/scala/frameless/forward/IsStreamingTests.scala b/dataset/src/test/scala/frameless/forward/IsStreamingTests.scala index dd1874977..b056bc409 100644 --- a/dataset/src/test/scala/frameless/forward/IsStreamingTests.scala +++ b/dataset/src/test/scala/frameless/forward/IsStreamingTests.scala @@ -14,4 +14,4 @@ class IsStreamingTests extends TypedDatasetSuite { check(forAll(prop[Int] _)) check(forAll(prop[String] _)) } -} \ No newline at end of file +} diff --git a/dataset/src/test/scala/frameless/forward/QueryExecutionTests.scala b/dataset/src/test/scala/frameless/forward/QueryExecutionTests.scala index d59e250df..0a0d82465 100644 --- a/dataset/src/test/scala/frameless/forward/QueryExecutionTests.scala +++ b/dataset/src/test/scala/frameless/forward/QueryExecutionTests.scala @@ -1,7 +1,7 @@ package frameless import org.scalacheck.Prop -import org.scalacheck.Prop.{forAll, _} +import org.scalacheck.Prop.{ forAll, _ } class QueryExecutionTests extends TypedDatasetSuite { test("queryExecution") { @@ -14,4 +14,4 @@ class QueryExecutionTests extends TypedDatasetSuite { check(forAll(prop[Int] _)) check(forAll(prop[String] _)) } -} \ No newline at end of file +} diff --git a/dataset/src/test/scala/frameless/forward/RandomSplitTests.scala b/dataset/src/test/scala/frameless/forward/RandomSplitTests.scala index 4cc9a4fde..63ab904f0 100644 --- a/dataset/src/test/scala/frameless/forward/RandomSplitTests.scala +++ b/dataset/src/test/scala/frameless/forward/RandomSplitTests.scala @@ -2,36 +2,45 @@ package frameless import org.scalacheck.Arbitrary.arbitrary import org.scalacheck.Prop._ -import org.scalacheck.{Arbitrary, Gen} +import org.scalacheck.{ Arbitrary, Gen } import scala.collection.JavaConverters._ import org.scalatest.matchers.should.Matchers class RandomSplitTests extends TypedDatasetSuite with Matchers { - val nonEmptyPositiveArray: Gen[Array[Double]] = Gen.nonEmptyListOf(Gen.posNum[Double]).map(_.toArray) + val nonEmptyPositiveArray: Gen[Array[Double]] = + Gen.nonEmptyListOf(Gen.posNum[Double]).map(_.toArray) test("randomSplit(weight, seed)") { - def prop[A: TypedEncoder : Arbitrary] = forAll(vectorGen[A], nonEmptyPositiveArray, arbitrary[Long]) { - (data: Vector[A], weights: Array[Double], seed: Long) => - val dataset = TypedDataset.create(data) + def prop[A: TypedEncoder: Arbitrary] = + forAll(vectorGen[A], nonEmptyPositiveArray, arbitrary[Long]) { + (data: Vector[A], weights: Array[Double], seed: Long) => + val dataset = TypedDataset.create(data) - dataset.randomSplit(weights, seed).map(_.count().run()) sameElements - dataset.dataset.randomSplit(weights, seed).map(_.count()) - } + dataset.randomSplit(weights, seed).map(_.count().run()) sameElements + dataset.dataset.randomSplit(weights, seed).map(_.count()) + } check(prop[Int]) check(prop[String]) } test("randomSplitAsList(weight, seed)") { - def prop[A: TypedEncoder : Arbitrary] = forAll(vectorGen[A], nonEmptyPositiveArray, arbitrary[Long]) { - (data: Vector[A], weights: Array[Double], seed: Long) => - val dataset = 
TypedDataset.create(data) - - dataset.randomSplitAsList(weights, seed).asScala.map(_.count().run()) sameElements - dataset.dataset.randomSplitAsList(weights, seed).asScala.map(_.count()) - } + def prop[A: TypedEncoder: Arbitrary] = + forAll(vectorGen[A], nonEmptyPositiveArray, arbitrary[Long]) { + (data: Vector[A], weights: Array[Double], seed: Long) => + val dataset = TypedDataset.create(data) + + dataset + .randomSplitAsList(weights, seed) + .asScala + .map(_.count().run()) sameElements + dataset.dataset + .randomSplitAsList(weights, seed) + .asScala + .map(_.count()) + } check(prop[Int]) check(prop[String]) diff --git a/dataset/src/test/scala/frameless/forward/SQLContextTests.scala b/dataset/src/test/scala/frameless/forward/SQLContextTests.scala index 700f29b05..7b524b076 100644 --- a/dataset/src/test/scala/frameless/forward/SQLContextTests.scala +++ b/dataset/src/test/scala/frameless/forward/SQLContextTests.scala @@ -1,7 +1,7 @@ package frameless import org.scalacheck.Prop -import org.scalacheck.Prop.{forAll, _} +import org.scalacheck.Prop.{ forAll, _ } class SQLContextTests extends TypedDatasetSuite { test("sqlContext") { diff --git a/dataset/src/test/scala/frameless/forward/SparkSessionTests.scala b/dataset/src/test/scala/frameless/forward/SparkSessionTests.scala index c5d0da338..ce3130d3b 100644 --- a/dataset/src/test/scala/frameless/forward/SparkSessionTests.scala +++ b/dataset/src/test/scala/frameless/forward/SparkSessionTests.scala @@ -14,4 +14,4 @@ class SparkSessionTests extends TypedDatasetSuite { check(forAll(prop[Int] _)) check(forAll(prop[String] _)) } -} \ No newline at end of file +} diff --git a/dataset/src/test/scala/frameless/forward/StorageLevelTests.scala b/dataset/src/test/scala/frameless/forward/StorageLevelTests.scala index 3ac93773e..6b9c0dcd0 100644 --- a/dataset/src/test/scala/frameless/forward/StorageLevelTests.scala +++ b/dataset/src/test/scala/frameless/forward/StorageLevelTests.scala @@ -3,27 +3,41 @@ package frameless import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel._ import org.scalacheck.Prop._ -import org.scalacheck.{Arbitrary, Gen} +import org.scalacheck.{ Arbitrary, Gen } class StorageLevelTests extends TypedDatasetSuite { - val storageLevelGen: Gen[StorageLevel] = Gen.oneOf(Seq(NONE, DISK_ONLY, DISK_ONLY_2, MEMORY_ONLY, - MEMORY_ONLY_2, MEMORY_ONLY_SER, MEMORY_ONLY_SER_2, MEMORY_AND_DISK, - MEMORY_AND_DISK_2, MEMORY_AND_DISK_SER, MEMORY_AND_DISK_SER_2, OFF_HEAP)) + val storageLevelGen: Gen[StorageLevel] = Gen.oneOf( + Seq( + NONE, + DISK_ONLY, + DISK_ONLY_2, + MEMORY_ONLY, + MEMORY_ONLY_2, + MEMORY_ONLY_SER, + MEMORY_ONLY_SER_2, + MEMORY_AND_DISK, + MEMORY_AND_DISK_2, + MEMORY_AND_DISK_SER, + MEMORY_AND_DISK_SER_2, + OFF_HEAP + ) + ) test("storageLevel") { - def prop[A: TypedEncoder : Arbitrary] = forAll(vectorGen[A], storageLevelGen) { - (data: Vector[A], storageLevel: StorageLevel) => - val dataset = TypedDataset.create(data) - if (storageLevel != StorageLevel.NONE) - dataset.persist(storageLevel) + def prop[A: TypedEncoder: Arbitrary] = + forAll(vectorGen[A], storageLevelGen) { + (data: Vector[A], storageLevel: StorageLevel) => + val dataset = TypedDataset.create(data) + if (storageLevel != StorageLevel.NONE) + dataset.persist(storageLevel) - dataset.count().run() + dataset.count().run() - dataset.storageLevel() ?= dataset.dataset.storageLevel - } + dataset.storageLevel() ?= dataset.dataset.storageLevel + } check(prop[Int]) check(prop[String]) } -} \ No newline at end of file +} diff --git 
a/dataset/src/test/scala/frameless/forward/TakeTests.scala b/dataset/src/test/scala/frameless/forward/TakeTests.scala index eec77bc80..a7f44a678 100644 --- a/dataset/src/test/scala/frameless/forward/TakeTests.scala +++ b/dataset/src/test/scala/frameless/forward/TakeTests.scala @@ -7,14 +7,22 @@ import scala.reflect.ClassTag class TakeTests extends TypedDatasetSuite { test("take") { def prop[A: TypedEncoder](n: Int, data: Vector[A]): Prop = - (n >= 0) ==> (TypedDataset.create(data).take(n).run().toVector =? data.take(n)) + (n >= 0) ==> (TypedDataset.create(data).take(n).run().toVector =? data + .take(n)) - def propArray[A: TypedEncoder: ClassTag](n: Int, data: Vector[X1[Array[A]]]): Prop = + def propArray[A: TypedEncoder: ClassTag]( + n: Int, + data: Vector[X1[Array[A]]] + ): Prop = (n >= 0) ==> { Prop { - TypedDataset.create(data).take(n).run().toVector.zip(data.take(n)).forall { - case (X1(l), X1(r)) => l sameElements r - } + TypedDataset + .create(data) + .take(n) + .run() + .toVector + .zip(data.take(n)) + .forall { case (X1(l), X1(r)) => l sameElements r } } } diff --git a/dataset/src/test/scala/frameless/forward/ToJSONTests.scala b/dataset/src/test/scala/frameless/forward/ToJSONTests.scala index 5ed79a9c9..5e78ea6d0 100644 --- a/dataset/src/test/scala/frameless/forward/ToJSONTests.scala +++ b/dataset/src/test/scala/frameless/forward/ToJSONTests.scala @@ -14,4 +14,4 @@ class ToJSONTests extends TypedDatasetSuite { check(forAll(prop[Int] _)) check(forAll(prop[String] _)) } -} \ No newline at end of file +} diff --git a/dataset/src/test/scala/frameless/forward/ToLocalIteratorTests.scala b/dataset/src/test/scala/frameless/forward/ToLocalIteratorTests.scala index faaf25caf..436215008 100644 --- a/dataset/src/test/scala/frameless/forward/ToLocalIteratorTests.scala +++ b/dataset/src/test/scala/frameless/forward/ToLocalIteratorTests.scala @@ -10,7 +10,14 @@ class ToLocalIteratorTests extends TypedDatasetSuite with Matchers { def prop[A: TypedEncoder](data: Vector[A]): Prop = { val dataset = TypedDataset.create(data) - dataset.toLocalIterator().run().asScala.toIterator sameElements dataset.dataset.toLocalIterator().asScala.toIterator + dataset + .toLocalIterator() + .run() + .asScala + .toIterator sameElements dataset.dataset + .toLocalIterator() + .asScala + .toIterator } check(forAll(prop[Int] _)) diff --git a/dataset/src/test/scala/frameless/forward/UnionTests.scala b/dataset/src/test/scala/frameless/forward/UnionTests.scala index 6cd8f4005..b927f24f7 100644 --- a/dataset/src/test/scala/frameless/forward/UnionTests.scala +++ b/dataset/src/test/scala/frameless/forward/UnionTests.scala @@ -30,11 +30,21 @@ class UnionTests extends TypedDatasetSuite { } test("Align fields for case classes") { - def prop[A: TypedEncoder, B: TypedEncoder](data1: Vector[(A, B)], data2: Vector[(A, B)]): Prop = { + def prop[A: TypedEncoder, B: TypedEncoder]( + data1: Vector[(A, B)], + data2: Vector[(A, B)] + ): Prop = { val dataset1 = TypedDataset.create(data1.map((Foo.apply[A, B] _).tupled)) - val dataset2 = TypedDataset.create(data2.map { case (a, b) => Bar[A, B](b, a) }) - val datasetUnion = dataset1.union(dataset2).collect().run().map(foo => (foo.x, foo.y)).toVector + val dataset2 = TypedDataset.create(data2.map { + case (a, b) => Bar[A, B](b, a) + }) + val datasetUnion = dataset1 + .union(dataset2) + .collect() + .run() + .map(foo => (foo.x, foo.y)) + .toVector val dataUnion = data1 union data2 datasetUnion ?= dataUnion @@ -45,11 +55,21 @@ class UnionTests extends TypedDatasetSuite { } test("Align fields for 
different number of columns") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder](data1: Vector[(A, B, C)], data2: Vector[(A, B)]): Prop = { + def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]( + data1: Vector[(A, B, C)], + data2: Vector[(A, B)] + ): Prop = { val dataset1 = TypedDataset.create(data2.map((Foo.apply[A, B] _).tupled)) - val dataset2 = TypedDataset.create(data1.map { case (a, b, c) => Baz[A, B, C](c, b, a) }) - val datasetUnion: Seq[(A, B)] = dataset1.union(dataset2).collect().run().map(foo => (foo.x, foo.y)).toVector + val dataset2 = TypedDataset.create(data1.map { + case (a, b, c) => Baz[A, B, C](c, b, a) + }) + val datasetUnion: Seq[(A, B)] = dataset1 + .union(dataset2) + .collect() + .run() + .map(foo => (foo.x, foo.y)) + .toVector val dataUnion = data2 union data1.map { case (a, b, _) => (a, b) } datasetUnion ?= dataUnion @@ -63,4 +83,4 @@ class UnionTests extends TypedDatasetSuite { final case class Foo[A, B](x: A, y: B) final case class Bar[A, B](y: B, x: A) final case class Baz[A, B, C](z: C, y: B, x: A) -final case class Wrong[A, B, C](a: A, b: B, c: C) \ No newline at end of file +final case class Wrong[A, B, C](a: A, b: B, c: C) diff --git a/dataset/src/test/scala/frameless/forward/WriteStreamTests.scala b/dataset/src/test/scala/frameless/forward/WriteStreamTests.scala index 368147c93..462e14e31 100644 --- a/dataset/src/test/scala/frameless/forward/WriteStreamTests.scala +++ b/dataset/src/test/scala/frameless/forward/WriteStreamTests.scala @@ -5,7 +5,7 @@ import java.util.UUID import org.apache.spark.sql.Encoder import org.apache.spark.sql.execution.streaming.MemoryStream import org.scalacheck.Prop._ -import org.scalacheck.{Arbitrary, Gen, Prop} +import org.scalacheck.{ Arbitrary, Gen, Prop } class WriteStreamTests extends TypedDatasetSuite { @@ -36,23 +36,34 @@ class WriteStreamTests extends TypedDatasetSuite { val checkpointPath = s"$TEST_OUTPUT_DIR/checkpoint/$uid" val inputStream = MemoryStream[A] val input = TypedDataset.create(inputStream.toDS()) - val inputter = input.writeStream.format("csv").option("checkpointLocation", s"$checkpointPath/input").start(filePath) + val inputter = input.writeStream + .format("csv") + .option("checkpointLocation", s"$checkpointPath/input") + .start(filePath) inputStream.addData(data) inputter.processAllAvailable() - val dataset = TypedDataset.createUnsafe(sqlContext.readStream.schema(input.schema).csv(filePath)) + val dataset = TypedDataset.createUnsafe( + sqlContext.readStream.schema(input.schema).csv(filePath) + ) - val tester = dataset - .writeStream + val tester = dataset.writeStream .option("checkpointLocation", s"$checkpointPath/tester") .format("memory") .queryName(s"testCsv_$uidNoHyphens") .start() tester.processAllAvailable() val output = spark.table(s"testCsv_$uidNoHyphens").as[A] - TypedDataset.create(data).collect().run().groupBy(identity) ?= output.collect().groupBy(identity).map { case (k, arr) => (k, arr.toSeq) } + TypedDataset.create(data).collect().run().groupBy(identity) ?= output + .collect() + .groupBy(identity) + .map { case (k, arr) => (k, arr.toSeq) } } - check(forAll(Gen.nonEmptyListOf(Gen.alphaNumStr.suchThat(_.nonEmpty)))(prop[String])) + check( + forAll(Gen.nonEmptyListOf(Gen.alphaNumStr.suchThat(_.nonEmpty)))( + prop[String] + ) + ) check(forAll(Gen.nonEmptyListOf(Arbitrary.arbitrary[Int]))(prop[Int])) } @@ -66,20 +77,27 @@ class WriteStreamTests extends TypedDatasetSuite { val checkpointPath = s"$TEST_OUTPUT_DIR/checkpoint/$uid" val inputStream = MemoryStream[A] val input = 
TypedDataset.create(inputStream.toDS()) - val inputter = input.writeStream.format("parquet").option("checkpointLocation", s"$checkpointPath/input").start(filePath) + val inputter = input.writeStream + .format("parquet") + .option("checkpointLocation", s"$checkpointPath/input") + .start(filePath) inputStream.addData(data) inputter.processAllAvailable() - val dataset = TypedDataset.createUnsafe(sqlContext.readStream.schema(input.schema).parquet(filePath)) + val dataset = TypedDataset.createUnsafe( + sqlContext.readStream.schema(input.schema).parquet(filePath) + ) - val tester = dataset - .writeStream + val tester = dataset.writeStream .option("checkpointLocation", s"$checkpointPath/tester") .format("memory") .queryName(s"testParquet_$uidNoHyphens") .start() tester.processAllAvailable() val output = spark.table(s"testParquet_$uidNoHyphens").as[A] - TypedDataset.create(data).collect().run().groupBy(identity) ?= output.collect().groupBy(identity).map { case (k, arr) => (k, arr.toSeq) } + TypedDataset.create(data).collect().run().groupBy(identity) ?= output + .collect() + .groupBy(identity) + .map { case (k, arr) => (k, arr.toSeq) } } check(forAll(Gen.nonEmptyListOf(genWriteExample))(prop[WriteExample])) diff --git a/dataset/src/test/scala/frameless/forward/WriteTests.scala b/dataset/src/test/scala/frameless/forward/WriteTests.scala index d5a9057cb..4504935c3 100644 --- a/dataset/src/test/scala/frameless/forward/WriteTests.scala +++ b/dataset/src/test/scala/frameless/forward/WriteTests.scala @@ -3,7 +3,7 @@ package frameless import java.util.UUID import org.scalacheck.Prop._ -import org.scalacheck.{Arbitrary, Gen, Prop} +import org.scalacheck.{ Arbitrary, Gen, Prop } class WriteTests extends TypedDatasetSuite { @@ -30,12 +30,19 @@ class WriteTests extends TypedDatasetSuite { val input = TypedDataset.create(data) input.write.csv(filePath) - val dataset = TypedDataset.createUnsafe(sqlContext.read.schema(input.schema).csv(filePath)) + val dataset = TypedDataset.createUnsafe( + sqlContext.read.schema(input.schema).csv(filePath) + ) - dataset.collect().run().groupBy(identity) ?= input.collect().run().groupBy(identity) + dataset.collect().run().groupBy(identity) ?= input + .collect() + .run() + .groupBy(identity) } - check(forAll(Gen.listOf(Gen.alphaNumStr.suchThat(_.nonEmpty)))(prop[String])) + check( + forAll(Gen.listOf(Gen.alphaNumStr.suchThat(_.nonEmpty)))(prop[String]) + ) check(forAll(prop[Int] _)) } @@ -45,9 +52,14 @@ class WriteTests extends TypedDatasetSuite { val input = TypedDataset.create(data) input.write.parquet(filePath) - val dataset = TypedDataset.createUnsafe(sqlContext.read.schema(input.schema).parquet(filePath)) + val dataset = TypedDataset.createUnsafe( + sqlContext.read.schema(input.schema).parquet(filePath) + ) - dataset.collect().run().groupBy(identity) ?= input.collect().run().groupBy(identity) + dataset.collect().run().groupBy(identity) ?= input + .collect() + .run() + .groupBy(identity) } check(forAll(Gen.listOf(genWriteExample))(prop[WriteExample])) @@ -56,4 +68,9 @@ class WriteTests extends TypedDatasetSuite { case class Nested(i: Double, v: String) case class OptionFieldsOnly(o1: Option[Int], o2: Option[Nested]) -case class WriteExample(i: Int, s: String, on: Option[Nested], ooo: Option[OptionFieldsOnly]) + +case class WriteExample( + i: Int, + s: String, + on: Option[Nested], + ooo: Option[OptionFieldsOnly]) diff --git a/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala b/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala 
index 201d93c63..a491b3816 100644 --- a/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala +++ b/dataset/src/test/scala/frameless/functions/AggregateFunctionsTests.scala @@ -1,31 +1,37 @@ package frameless package functions -import frameless.{TypedAggregate, TypedColumn} +import frameless.{ TypedAggregate, TypedColumn } import frameless.functions.aggregate._ -import org.apache.spark.sql.{Column, Encoder} -import org.scalacheck.{Gen, Prop} +import org.apache.spark.sql.{ Column, Encoder } +import org.scalacheck.{ Gen, Prop } import org.scalacheck.Prop._ import org.scalatest.exceptions.GeneratorDrivenPropertyCheckFailedException class AggregateFunctionsTests extends TypedDatasetSuite { - def sparkSchema[A: TypedEncoder, U](f: TypedColumn[X1[A], A] => TypedAggregate[X1[A], U]): Prop = { + + def sparkSchema[A: TypedEncoder, U]( + f: TypedColumn[X1[A], A] => TypedAggregate[X1[A], U] + ): Prop = { val df = TypedDataset.create[X1[A]](Nil) val col = f(df.col('a)) val sumDf = df.agg(col) - TypedExpressionEncoder.targetStructType(sumDf.encoder) ?= sumDf.dataset.schema + TypedExpressionEncoder.targetStructType( + sumDf.encoder + ) ?= sumDf.dataset.schema } test("sum") { case class Sum4Tests[A, B](sum: Seq[A] => B) - def prop[A: TypedEncoder, Out: TypedEncoder : Numeric](xs: List[A])( - implicit - summable: CatalystSummable[A, Out], - summer: Sum4Tests[A, Out] - ): Prop = { + def prop[A: TypedEncoder, Out: TypedEncoder: Numeric]( + xs: List[A] + )(implicit + summable: CatalystSummable[A, Out], + summer: Sum4Tests[A, Out] + ): Prop = { val dataset = TypedDataset.create(xs.map(X1(_))) val A = dataset.col[A]('a) @@ -33,7 +39,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite { datasetSum match { case x :: Nil => approximatelyEqual(summer.sum(xs), x) - case other => falsified + case other => falsified } } @@ -61,27 +67,31 @@ class AggregateFunctionsTests extends TypedDatasetSuite { test("sumDistinct") { case class Sum4Tests[A, B](sum: Seq[A] => B) - def prop[A: TypedEncoder, Out: TypedEncoder : Numeric](xs: List[A])( - implicit - summable: CatalystSummable[A, Out], - summer: Sum4Tests[A, Out] - ): Prop = { + def prop[A: TypedEncoder, Out: TypedEncoder: Numeric]( + xs: List[A] + )(implicit + summable: CatalystSummable[A, Out], + summer: Sum4Tests[A, Out] + ): Prop = { val dataset = TypedDataset.create(xs.map(X1(_))) val A = dataset.col[A]('a) - val datasetSum: List[Out] = dataset.agg(sumDistinct(A)).collect().run().toList + val datasetSum: List[Out] = + dataset.agg(sumDistinct(A)).collect().run().toList datasetSum match { case x :: Nil => approximatelyEqual(summer.sum(xs), x) - case other => falsified + case other => falsified } } // Replicate Spark's behaviour : Ints and Shorts are cast to Long // https://github.com/apache/spark/blob/7eb2ca8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L37 implicit def summerLong = Sum4Tests[Long, Long](_.toSet.sum) - implicit def summerInt = Sum4Tests[Int, Long]( x => x.toSet.map((_:Int).toLong).sum) - implicit def summerShort = Sum4Tests[Short, Long](x => x.toSet.map((_:Short).toLong).sum) + implicit def summerInt = + Sum4Tests[Int, Long](x => x.toSet.map((_: Int).toLong).sum) + implicit def summerShort = + Sum4Tests[Short, Long](x => x.toSet.map((_: Short).toLong).sum) check(forAll(prop[Long, Long] _)) check(forAll(prop[Int, Long] _)) @@ -95,33 +105,41 @@ class AggregateFunctionsTests extends TypedDatasetSuite { test("avg") { case class Averager4Tests[A, B](avg: Seq[A] => B) - def prop[A: 
TypedEncoder, Out: TypedEncoder : Numeric](xs: List[A])( - implicit - averageable: CatalystAverageable[A, Out], - averager: Averager4Tests[A, Out] - ): Prop = { + def prop[A: TypedEncoder, Out: TypedEncoder: Numeric]( + xs: List[A] + )(implicit + averageable: CatalystAverageable[A, Out], + averager: Averager4Tests[A, Out] + ): Prop = { val dataset = TypedDataset.create(xs.map(X1(_))) val A = dataset.col[A]('a) val datasetAvg: Vector[Out] = dataset.agg(avg(A)).collect().run().toVector if (datasetAvg.size > 2) falsified - else xs match { - case Nil => datasetAvg ?= Vector() - case _ :: _ => datasetAvg.headOption match { - case Some(x) => approximatelyEqual(averager.avg(xs), x) - case None => falsified + else + xs match { + case Nil => datasetAvg ?= Vector() + case _ :: _ => + datasetAvg.headOption match { + case Some(x) => approximatelyEqual(averager.avg(xs), x) + case None => falsified + } } - } } // Replicate Spark's behaviour : If the datatype isn't BigDecimal cast type to Double // https://github.com/apache/spark/blob/7eb2ca8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L50 - implicit def averageDecimal = Averager4Tests[BigDecimal, BigDecimal](as => as.sum/as.size) - implicit def averageDouble = Averager4Tests[Double, Double](as => as.sum/as.size) - implicit def averageLong = Averager4Tests[Long, Double](as => as.map(_.toDouble).sum/as.size) - implicit def averageInt = Averager4Tests[Int, Double](as => as.map(_.toDouble).sum/as.size) - implicit def averageShort = Averager4Tests[Short, Double](as => as.map(_.toDouble).sum/as.size) + implicit def averageDecimal = + Averager4Tests[BigDecimal, BigDecimal](as => as.sum / as.size) + implicit def averageDouble = + Averager4Tests[Double, Double](as => as.sum / as.size) + implicit def averageLong = + Averager4Tests[Long, Double](as => as.map(_.toDouble).sum / as.size) + implicit def averageInt = + Averager4Tests[Int, Double](as => as.map(_.toDouble).sum / as.size) + implicit def averageShort = + Averager4Tests[Short, Double](as => as.map(_.toDouble).sum / as.size) /* under 3.4 an oddity was detected: Falsified after 2 successful property evaluations. 
@@ -141,20 +159,27 @@ class AggregateFunctionsTests extends TypedDatasetSuite { } test("stddev and variance") { - def prop[A: TypedEncoder : CatalystVariance : Numeric](xs: List[A]): Prop = { + def prop[A: TypedEncoder: CatalystVariance: Numeric](xs: List[A]): Prop = { val numeric = implicitly[Numeric[A]] val dataset = TypedDataset.create(xs.map(X1(_))) val A = dataset.col[A]('a) - val datasetStdOpt = dataset.agg(stddev(A)).collect().run().toVector.headOption - val datasetVarOpt = dataset.agg(variance(A)).collect().run().toVector.headOption + val datasetStdOpt = + dataset.agg(stddev(A)).collect().run().toVector.headOption + val datasetVarOpt = + dataset.agg(variance(A)).collect().run().toVector.headOption - val std = sc.parallelize(xs.map(implicitly[Numeric[A]].toDouble)).sampleStdev() - val `var` = sc.parallelize(xs.map(implicitly[Numeric[A]].toDouble)).sampleVariance() + val std = + sc.parallelize(xs.map(implicitly[Numeric[A]].toDouble)).sampleStdev() + val `var` = + sc.parallelize(xs.map(implicitly[Numeric[A]].toDouble)).sampleVariance() (datasetStdOpt, datasetVarOpt) match { case (Some(datasetStd), Some(datasetVar)) => - approximatelyEqual(datasetStd, std) && approximatelyEqual(datasetVar, `var`) + approximatelyEqual(datasetStd, std) && approximatelyEqual( + datasetVar, + `var` + ) case _ => proved } } @@ -167,9 +192,17 @@ class AggregateFunctionsTests extends TypedDatasetSuite { } test("litAggr") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder](xs: List[A], b: B, c: C): Prop = { + def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]( + xs: List[A], + b: B, + c: C + ): Prop = { val dataset = TypedDataset.create(xs) - val (r1, rb, rc, rcount) = dataset.agg(count().lit(1), litAggr(b), litAggr(c), count()).collect().run().head + val (r1, rb, rc, rcount) = dataset + .agg(count().lit(1), litAggr(b), litAggr(c), count()) + .collect() + .run() + .head (rcount ?= xs.size.toLong) && (r1 ?= 1) && (rb ?= b) && (rc ?= c) } @@ -203,7 +236,11 @@ class AggregateFunctionsTests extends TypedDatasetSuite { } test("max") { - def prop[A: TypedEncoder: CatalystOrdered](xs: List[A])(implicit o: Ordering[A]): Prop = { + def prop[A: TypedEncoder: CatalystOrdered]( + xs: List[A] + )(implicit + o: Ordering[A] + ): Prop = { val dataset = TypedDataset.create(xs.map(X1(_))) val A = dataset.col[A]('a) val datasetMax = dataset.agg(max(A)).collect().run().toList @@ -225,14 +262,18 @@ class AggregateFunctionsTests extends TypedDatasetSuite { val A = dataset.col[Long]('a) val datasetMax = dataset.agg(max(A) * 2).collect().run().headOption - datasetMax ?= (if(xs.isEmpty) None else Some(xs.max * 2)) + datasetMax ?= (if (xs.isEmpty) None else Some(xs.max * 2)) } check(forAll(prop _)) } test("min") { - def prop[A: TypedEncoder: CatalystOrdered](xs: List[A])(implicit o: Ordering[A]): Prop = { + def prop[A: TypedEncoder: CatalystOrdered]( + xs: List[A] + )(implicit + o: Ordering[A] + ): Prop = { val dataset = TypedDataset.create(xs.map(X1(_))) val A = dataset.col[A]('a) @@ -301,8 +342,13 @@ class AggregateFunctionsTests extends TypedDatasetSuite { check { forAll(getLowCardinalityKVPairs) { xs: Vector[(Int, Int)] => val tds = TypedDataset.create(xs) - val tdsRes: Seq[(Int, Long)] = tds.groupBy(tds('_1)).agg(countDistinct(tds('_2))).collect().run() - tdsRes.toMap ?= xs.groupBy(_._1).mapValues(_.map(_._2).distinct.size.toLong).toSeq.toMap + val tdsRes: Seq[(Int, Long)] = + tds.groupBy(tds('_1)).agg(countDistinct(tds('_2))).collect().run() + tdsRes.toMap ?= xs + .groupBy(_._1) + 
.mapValues(_.map(_._2).distinct.size.toLong) + .toSeq + .toMap } } } @@ -310,7 +356,11 @@ class AggregateFunctionsTests extends TypedDatasetSuite { test("approxCountDistinct") { // Simple version of #approximatelyEqual() // Default maximum estimation error of HyperLogLog in Spark is 5% - def approxEqual(actual: Long, estimated: Long, allowedDeviationPercentile: Double = 0.05): Boolean = { + def approxEqual( + actual: Long, + estimated: Long, + allowedDeviationPercentile: Double = 0.05 + ): Boolean = { val delta: Long = Math.abs(actual - estimated) delta / actual.toDouble < allowedDeviationPercentile * 2 } @@ -319,7 +369,11 @@ class AggregateFunctionsTests extends TypedDatasetSuite { forAll(getLowCardinalityKVPairs) { xs: Vector[(Int, Int)] => val tds = TypedDataset.create(xs) val tdsRes: Seq[(Int, Long, Long)] = - tds.groupBy(tds('_1)).agg(countDistinct(tds('_2)), approxCountDistinct(tds('_2))).collect().run() + tds + .groupBy(tds('_1)) + .agg(countDistinct(tds('_2)), approxCountDistinct(tds('_2))) + .collect() + .run() tdsRes.forall { case (_, v1, v2) => approxEqual(v1, v2) } } } @@ -329,18 +383,28 @@ class AggregateFunctionsTests extends TypedDatasetSuite { val tds = TypedDataset.create(xs) val allowedError = 0.1 // 10% val tdsRes: Seq[(Int, Long, Long)] = - tds.groupBy(tds('_1)).agg(countDistinct(tds('_2)), approxCountDistinct(tds('_2), allowedError)).collect().run() + tds + .groupBy(tds('_1)) + .agg( + countDistinct(tds('_2)), + approxCountDistinct(tds('_2), allowedError) + ) + .collect() + .run() tdsRes.forall { case (_, v1, v2) => approxEqual(v1, v2, allowedError) } } } } test("collectList") { - def prop[A: TypedEncoder : Ordering](xs: List[X2[A, A]]): Prop = { + def prop[A: TypedEncoder: Ordering](xs: List[X2[A, A]]): Prop = { val tds = TypedDataset.create(xs) - val tdsRes: Seq[(A, Vector[A])] = tds.groupBy(tds('a)).agg(collectList(tds('b))).collect().run() + val tdsRes: Seq[(A, Vector[A])] = + tds.groupBy(tds('a)).agg(collectList(tds('b))).collect().run() - tdsRes.toMap.map { case (k, v) => k -> v.sorted } ?= xs.groupBy(_.a).map { case (k, v) => k -> v.map(_.b).toVector.sorted } + tdsRes.toMap.map { case (k, v) => k -> v.sorted } ?= xs.groupBy(_.a).map { + case (k, v) => k -> v.map(_.b).toVector.sorted + } } check(forAll(prop[Long] _)) @@ -350,11 +414,14 @@ class AggregateFunctionsTests extends TypedDatasetSuite { } test("collectSet") { - def prop[A: TypedEncoder : Ordering](xs: List[X2[A, A]]): Prop = { + def prop[A: TypedEncoder: Ordering](xs: List[X2[A, A]]): Prop = { val tds = TypedDataset.create(xs) - val tdsRes: Seq[(A, Vector[A])] = tds.groupBy(tds('a)).agg(collectSet(tds('b))).collect().run() + val tdsRes: Seq[(A, Vector[A])] = + tds.groupBy(tds('a)).agg(collectSet(tds('b))).collect().run() - tdsRes.toMap.map { case (k, v) => k -> v.toSet } ?= xs.groupBy(_.a).map { case (k, v) => k -> v.map(_.b).toSet } + tdsRes.toMap.map { case (k, v) => k -> v.toSet } ?= xs.groupBy(_.a).map { + case (k, v) => k -> v.map(_.b).toSet + } } check(forAll(prop[Long] _)) @@ -379,77 +446,76 @@ class AggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[BigDecimal] _)) } - - def bivariatePropTemplate[A: TypedEncoder, B: TypedEncoder] - ( - xs: List[X3[Int, A, B]] - ) - ( - framelessFun: (TypedColumn[X3[Int, A, B], A], TypedColumn[X3[Int, A, B], B]) => TypedAggregate[X3[Int, A, B], Option[Double]], - sparkFun: (Column, Column) => Column - ) - ( - implicit - encEv: Encoder[(Int, A, B)], - encEv2: Encoder[(Int,Option[Double])], - evCanBeDoubleA: CatalystCast[A, Double], - 
evCanBeDoubleB: CatalystCast[B, Double] - ): Prop = { + def bivariatePropTemplate[A: TypedEncoder, B: TypedEncoder]( + xs: List[X3[Int, A, B]] + )(framelessFun: ( + TypedColumn[X3[Int, A, B], A], + TypedColumn[X3[Int, A, B], B] + ) => TypedAggregate[X3[Int, A, B], Option[Double]], + sparkFun: (Column, Column) => Column + )(implicit + encEv: Encoder[(Int, A, B)], + encEv2: Encoder[(Int, Option[Double])], + evCanBeDoubleA: CatalystCast[A, Double], + evCanBeDoubleB: CatalystCast[B, Double] + ): Prop = { val tds = TypedDataset.create(xs) // Typed implementation of bivar stats function - val tdBivar = tds.groupBy(tds('a)).agg(framelessFun(tds('b), tds('c))).deserialized.map(kv => - (kv._1, kv._2.flatMap(DoubleBehaviourUtils.nanNullHandler)) - ).collect().run() + val tdBivar = tds + .groupBy(tds('a)) + .agg(framelessFun(tds('b), tds('c))) + .deserialized + .map(kv => (kv._1, kv._2.flatMap(DoubleBehaviourUtils.nanNullHandler))) + .collect() + .run() val cDF = session.createDataset(xs.map(x => (x.a, x.b, x.c))) // Comparison implementation of bivar stats functions val compBivar = cDF .groupBy(cDF("_1")) .agg(sparkFun(cDF("_2"), cDF("_3"))) - .map( - row => { - val grp = row.getInt(0) - (grp, DoubleBehaviourUtils.nanNullHandler(row.get(1))) - } - ) + .map(row => { + val grp = row.getInt(0) + (grp, DoubleBehaviourUtils.nanNullHandler(row.get(1))) + }) // Should be the same tdBivar.toMap ?= compBivar.collect().toMap } - def univariatePropTemplate[A: TypedEncoder] - ( - xs: List[X2[Int, A]] - ) - ( - framelessFun: (TypedColumn[X2[Int, A], A]) => TypedAggregate[X2[Int, A], Option[Double]], - sparkFun: (Column) => Column - ) - ( - implicit - encEv: Encoder[(Int, A)], - encEv2: Encoder[(Int,Option[Double])], - evCanBeDoubleA: CatalystCast[A, Double] - ): Prop = { + def univariatePropTemplate[A: TypedEncoder]( + xs: List[X2[Int, A]] + )(framelessFun: (TypedColumn[X2[Int, A], A]) => TypedAggregate[ + X2[Int, A], + Option[Double] + ], + sparkFun: (Column) => Column + )(implicit + encEv: Encoder[(Int, A)], + encEv2: Encoder[(Int, Option[Double])], + evCanBeDoubleA: CatalystCast[A, Double] + ): Prop = { val tds = TypedDataset.create(xs) - //typed implementation of univariate stats function - val tdUnivar = tds.groupBy(tds('a)).agg(framelessFun(tds('b))).deserialized.map(kv => - (kv._1, kv._2.flatMap(DoubleBehaviourUtils.nanNullHandler)) - ).collect().run() + // typed implementation of univariate stats function + val tdUnivar = tds + .groupBy(tds('a)) + .agg(framelessFun(tds('b))) + .deserialized + .map(kv => (kv._1, kv._2.flatMap(DoubleBehaviourUtils.nanNullHandler))) + .collect() + .run() val cDF = session.createDataset(xs.map(x => (x.a, x.b))) // Comparison implementation of bivar stats functions val compUnivar = cDF .groupBy(cDF("_1")) .agg(sparkFun(cDF("_2"))) - .map( - row => { - val grp = row.getInt(0) - (grp, DoubleBehaviourUtils.nanNullHandler(row.get(1))) - } - ) + .map(row => { + val grp = row.getInt(0) + (grp, DoubleBehaviourUtils.nanNullHandler(row.get(1))) + }) // Should be the same tdUnivar.toMap ?= compUnivar.collect().toMap @@ -459,12 +525,16 @@ class AggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder, B: TypedEncoder](xs: List[X3[Int, A, B]])( - implicit - encEv: Encoder[(Int, A, B)], - evCanBeDoubleA: CatalystCast[A, Double], - evCanBeDoubleB: CatalystCast[B, Double] - ): Prop = bivariatePropTemplate(xs)(corr[A,B,X3[Int, A, B]],org.apache.spark.sql.functions.corr) + def prop[A: TypedEncoder, B: TypedEncoder]( + xs: 
List[X3[Int, A, B]] + )(implicit + encEv: Encoder[(Int, A, B)], + evCanBeDoubleA: CatalystCast[A, Double], + evCanBeDoubleB: CatalystCast[B, Double] + ): Prop = bivariatePropTemplate(xs)( + corr[A, B, X3[Int, A, B]], + org.apache.spark.sql.functions.corr + ) check(forAll(prop[Double, Double] _)) check(forAll(prop[Double, Int] _)) @@ -477,12 +547,13 @@ class AggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder, B: TypedEncoder](xs: List[X3[Int, A, B]])( - implicit - encEv: Encoder[(Int, A, B)], - evCanBeDoubleA: CatalystCast[A, Double], - evCanBeDoubleB: CatalystCast[B, Double] - ): Prop = bivariatePropTemplate(xs)( + def prop[A: TypedEncoder, B: TypedEncoder]( + xs: List[X3[Int, A, B]] + )(implicit + encEv: Encoder[(Int, A, B)], + evCanBeDoubleA: CatalystCast[A, Double], + evCanBeDoubleB: CatalystCast[B, Double] + ): Prop = bivariatePropTemplate(xs)( covarPop[A, B, X3[Int, A, B]], org.apache.spark.sql.functions.covar_pop ) @@ -498,12 +569,13 @@ class AggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder, B: TypedEncoder](xs: List[X3[Int, A, B]])( - implicit - encEv: Encoder[(Int, A, B)], - evCanBeDoubleA: CatalystCast[A, Double], - evCanBeDoubleB: CatalystCast[B, Double] - ): Prop = bivariatePropTemplate(xs)( + def prop[A: TypedEncoder, B: TypedEncoder]( + xs: List[X3[Int, A, B]] + )(implicit + encEv: Encoder[(Int, A, B)], + evCanBeDoubleA: CatalystCast[A, Double], + evCanBeDoubleB: CatalystCast[B, Double] + ): Prop = bivariatePropTemplate(xs)( covarSamp[A, B, X3[Int, A, B]], org.apache.spark.sql.functions.covar_samp ) @@ -519,11 +591,12 @@ class AggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder](xs: List[X2[Int, A]])( - implicit - encEv: Encoder[(Int, A)], - evCanBeDoubleA: CatalystCast[A, Double] - ): Prop = univariatePropTemplate(xs)( + def prop[A: TypedEncoder]( + xs: List[X2[Int, A]] + )(implicit + encEv: Encoder[(Int, A)], + evCanBeDoubleA: CatalystCast[A, Double] + ): Prop = univariatePropTemplate(xs)( kurtosis[A, X2[Int, A]], org.apache.spark.sql.functions.kurtosis ) @@ -539,11 +612,12 @@ class AggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder](xs: List[X2[Int, A]])( - implicit - encEv: Encoder[(Int, A)], - evCanBeDoubleA: CatalystCast[A, Double] - ): Prop = univariatePropTemplate(xs)( + def prop[A: TypedEncoder]( + xs: List[X2[Int, A]] + )(implicit + encEv: Encoder[(Int, A)], + evCanBeDoubleA: CatalystCast[A, Double] + ): Prop = univariatePropTemplate(xs)( skewness[A, X2[Int, A]], org.apache.spark.sql.functions.skewness ) @@ -559,11 +633,12 @@ class AggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder](xs: List[X2[Int, A]])( - implicit - encEv: Encoder[(Int, A)], - evCanBeDoubleA: CatalystCast[A, Double] - ): Prop = univariatePropTemplate(xs)( + def prop[A: TypedEncoder]( + xs: List[X2[Int, A]] + )(implicit + encEv: Encoder[(Int, A)], + evCanBeDoubleA: CatalystCast[A, Double] + ): Prop = univariatePropTemplate(xs)( stddevPop[A, X2[Int, A]], org.apache.spark.sql.functions.stddev_pop ) @@ -579,11 +654,12 @@ class AggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder](xs: List[X2[Int, A]])( - implicit - encEv: Encoder[(Int, A)], - evCanBeDoubleA: CatalystCast[A, 
Double] - ): Prop = univariatePropTemplate(xs)( + def prop[A: TypedEncoder]( + xs: List[X2[Int, A]] + )(implicit + encEv: Encoder[(Int, A)], + evCanBeDoubleA: CatalystCast[A, Double] + ): Prop = univariatePropTemplate(xs)( stddevSamp[A, X2[Int, A]], org.apache.spark.sql.functions.stddev_samp ) diff --git a/dataset/src/test/scala/frameless/functions/DateTimeStringBehaviourUtils.scala b/dataset/src/test/scala/frameless/functions/DateTimeStringBehaviourUtils.scala index e22fe4337..42cccb2eb 100644 --- a/dataset/src/test/scala/frameless/functions/DateTimeStringBehaviourUtils.scala +++ b/dataset/src/test/scala/frameless/functions/DateTimeStringBehaviourUtils.scala @@ -3,8 +3,9 @@ package frameless.functions import org.apache.spark.sql.Row object DateTimeStringBehaviourUtils { + val nullHandler: Row => Option[Int] = _.get(0) match { case i: Int => Some(i) - case _ => None + case _ => None } } diff --git a/dataset/src/test/scala/frameless/functions/DoubleBehaviourUtils.scala b/dataset/src/test/scala/frameless/functions/DoubleBehaviourUtils.scala index f3a8be581..d19b08ea1 100644 --- a/dataset/src/test/scala/frameless/functions/DoubleBehaviourUtils.scala +++ b/dataset/src/test/scala/frameless/functions/DoubleBehaviourUtils.scala @@ -2,19 +2,22 @@ package frameless package functions /** - * Some statistical functions in Spark can result in Double, Double.NaN or Null. - * This tends to break ?= of the property based testing. Use the nanNullHandler function - * here to alleviate this by mapping this NaN and Null to None. This will result in - * functioning comparison again. - */ + * Some statistical functions in Spark can result in Double, Double.NaN or Null. + * This tends to break ?= of the property based testing. Use the nanNullHandler function + * here to alleviate this by mapping this NaN and Null to None. This will result in + * functioning comparison again. + */ object DoubleBehaviourUtils { + // Mapping with this function is needed because spark uses Double.NaN for some semantics in the // correlation function. ?= for prop testing will use == underlying and will break because Double.NaN != Double.NaN - private val nanHandler: Double => Option[Double] = value => if (!value.equals(Double.NaN)) Option(value) else None + private val nanHandler: Double => Option[Double] = value => + if (!value.equals(Double.NaN)) Option(value) else None + // Making sure that null => None and does not result in 0.0d because of row.getAs[Double]'s use of .asInstanceOf val nanNullHandler: Any => Option[Double] = { - case null => None + case null => None case d: Double => nanHandler(d) - case _ => ??? + case _ => ??? 
} } diff --git a/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala b/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala index 470d58e5f..d6dcac14b 100644 --- a/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala +++ b/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala @@ -7,9 +7,14 @@ import java.nio.charset.StandardCharsets import frameless.functions.nonAggregate._ import org.apache.commons.io.FileUtils -import org.apache.spark.sql.{Column, Encoder, SaveMode, functions => sparkFunctions} +import org.apache.spark.sql.{ + Column, + Encoder, + SaveMode, + functions => sparkFunctions +} import org.scalacheck.Prop._ -import org.scalacheck.{Arbitrary, Gen, Prop} +import org.scalacheck.{ Arbitrary, Gen, Prop } import scala.annotation.nowarn @@ -17,30 +22,33 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val testTempFiles = "target/testoutput" object NonNegativeGenerators { + val doubleGen = for { - s <- Gen.chooseNum(1, Int.MaxValue) - e <- Gen.chooseNum(1, Int.MaxValue) + s <- Gen.chooseNum(1, Int.MaxValue) + e <- Gen.chooseNum(1, Int.MaxValue) res: Double = s.toDouble / e.toDouble } yield res - val intGen: Gen[Int] = Gen.chooseNum(1, Int.MaxValue) + val intGen: Gen[Int] = Gen.chooseNum(1, Int.MaxValue) val shortGen: Gen[Short] = Gen.chooseNum(1, Short.MaxValue) - val longGen: Gen[Long] = Gen.chooseNum(1, Long.MaxValue) - val byteGen: Gen[Byte] = Gen.chooseNum(1, Byte.MaxValue) + val longGen: Gen[Long] = Gen.chooseNum(1, Long.MaxValue) + val byteGen: Gen[Byte] = Gen.chooseNum(1, Byte.MaxValue) } object NonNegativeArbitraryNumericValues { import NonNegativeGenerators._ - implicit val arbInt: Arbitrary[Int] = Arbitrary(intGen) - implicit val arbDouble: Arbitrary[Double] = Arbitrary(doubleGen) - implicit val arbLong: Arbitrary[Long] = Arbitrary(longGen) - implicit val arbShort: Arbitrary[Short] = Arbitrary(shortGen) - implicit val arbByte: Arbitrary[Byte] = Arbitrary(byteGen) + implicit val arbInt: Arbitrary[Int] = Arbitrary(intGen) + implicit val arbDouble: Arbitrary[Double] = Arbitrary(doubleGen) + implicit val arbLong: Arbitrary[Long] = Arbitrary(longGen) + implicit val arbShort: Arbitrary[Short] = Arbitrary(shortGen) + implicit val arbByte: Arbitrary[Byte] = Arbitrary(byteGen) } private val base64Encoder = Base64.getEncoder + private def base64X1String(x1: X1[String]): X1[String] = { - def base64(str: String): String = base64Encoder.encodeToString(str.getBytes(StandardCharsets.UTF_8)) + def base64(str: String): String = + base64Encoder.encodeToString(str.getBytes(StandardCharsets.UTF_8)) x1.copy(a = base64(x1.a)) } @@ -53,9 +61,12 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder](values: List[X1[A]])( - implicit encX1:Encoder[X1[A]], - catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, B]) = { + def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]], + catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, B] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.negate(cDS("a"))) @@ -65,11 +76,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val typedDS = TypedDataset.create(values) val col = typedDS('a) - val res = typedDS - .select(negate(col)) - .collect() - .run() - .toList + val res = 
typedDS.select(negate(col)).collect().run().toList res ?= resCompare } @@ -77,7 +84,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[Byte, Byte] _)) check(forAll(prop[Short, Short] _)) check(forAll(prop[Int, Int] _)) - check(forAll(prop[Long, Long] _)) + check(forAll(prop[Long, Long] _)) check(forAll(prop[BigDecimal, java.math.BigDecimal] _)) } @@ -85,7 +92,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(values: List[X1[Boolean]], fromBase: Int, toBase: Int)(implicit encX1:Encoder[X1[Boolean]]) = { + def prop( + values: List[X1[Boolean]], + fromBase: Int, + toBase: Int + )(implicit + encX1: Encoder[X1[Boolean]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS @@ -96,11 +109,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val typedDS = TypedDataset.create(values) val col = typedDS('a) - val res = typedDS - .select(not(col)) - .collect() - .run() - .toList + val res = typedDS.select(not(col)).collect().run().toList res ?= resCompare } @@ -112,7 +121,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(values: List[X1[String]], fromBase: Int, toBase: Int)(implicit encX1:Encoder[X1[String]]) = { + def prop( + values: List[X1[String]], + fromBase: Int, + toBase: Int + )(implicit + encX1: Encoder[X1[String]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS @@ -123,11 +138,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val typedDS = TypedDataset.create(values) val col = typedDS('a) - val res = typedDS - .select(conv(col, fromBase, toBase)) - .collect() - .run() - .toList + val res = + typedDS.select(conv(col, fromBase, toBase)).collect().run().toList res ?= resCompare } @@ -139,7 +151,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = { + def prop[A: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.degrees(cDS("a"))) @@ -149,11 +165,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val typedDS = TypedDataset.create(values) val col = typedDS('a) - val res = typedDS - .select(degrees(col)) - .collect() - .run() - .toList + val res = typedDS.select(degrees(col)).collect().run().toList res ?= resCompare } @@ -161,12 +173,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[Byte] _)) check(forAll(prop[Short] _)) check(forAll(prop[Int] _)) - check(forAll(prop[Long] _)) + check(forAll(prop[Long] _)) check(forAll(prop[BigDecimal] _)) } - def propBitShift[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder](typedDS: TypedDataset[X1[A]]) - (typedCol: TypedColumn[X1[A], B], sparkFunc: (Column,Int) => Column, numBits: Int): Prop = { + def propBitShift[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder]( + typedDS: TypedDataset[X1[A]] + )(typedCol: TypedColumn[X1[A], B], + sparkFunc: (Column, Int) => Column, + numBits: Int + ): Prop = { val spark = session import spark.implicits._ @@ -176,11 +192,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .collect() .toList - val res = typedDS - .select(typedCol) - .collect() - .run() - .toList + val res = typedDS.select(typedCol).collect().run().toList res ?= resCompare } @@ 
-190,11 +202,19 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { import spark.implicits._ @nowarn // supress sparkFunctions.shiftRightUnsigned call which is used to maintain Spark 3.1.x backwards compat - def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder] - (values: List[X1[A]], numBits: Int) - (implicit catalystBitShift: CatalystBitShift[A, B], encX1: Encoder[X1[A]]) = { + def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder]( + values: List[X1[A]], + numBits: Int + )(implicit + catalystBitShift: CatalystBitShift[A, B], + encX1: Encoder[X1[A]] + ) = { val typedDS = TypedDataset.create(values) - propBitShift(typedDS)(shiftRightUnsigned(typedDS('a), numBits), sparkFunctions.shiftRightUnsigned, numBits) + propBitShift(typedDS)( + shiftRightUnsigned(typedDS('a), numBits), + sparkFunctions.shiftRightUnsigned, + numBits + ) } check(forAll(prop[Byte, Int] _)) @@ -209,11 +229,19 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { import spark.implicits._ @nowarn // supress sparkFunctions.shiftRight call which is used to maintain Spark 3.1.x backwards compat - def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder] - (values: List[X1[A]], numBits: Int) - (implicit catalystBitShift: CatalystBitShift[A, B], encX1: Encoder[X1[A]]) = { + def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder]( + values: List[X1[A]], + numBits: Int + )(implicit + catalystBitShift: CatalystBitShift[A, B], + encX1: Encoder[X1[A]] + ) = { val typedDS = TypedDataset.create(values) - propBitShift(typedDS)(shiftRight(typedDS('a), numBits), sparkFunctions.shiftRight, numBits) + propBitShift(typedDS)( + shiftRight(typedDS('a), numBits), + sparkFunctions.shiftRight, + numBits + ) } check(forAll(prop[Byte, Int] _)) @@ -228,11 +256,19 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { import spark.implicits._ @nowarn // supress sparkFunctions.shiftLeft call which is used to maintain Spark 3.1.x backwards compat - def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder] - (values: List[X1[A]], numBits: Int) - (implicit catalystBitShift: CatalystBitShift[A, B], encX1: Encoder[X1[A]]) = { + def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder]( + values: List[X1[A]], + numBits: Int + )(implicit + catalystBitShift: CatalystBitShift[A, B], + encX1: Encoder[X1[A]] + ) = { val typedDS = TypedDataset.create(values) - propBitShift(typedDS)(shiftLeft(typedDS('a), numBits), sparkFunctions.shiftLeft, numBits) + propBitShift(typedDS)( + shiftLeft(typedDS('a), numBits), + sparkFunctions.shiftLeft, + numBits + ) } check(forAll(prop[Byte, Int] _)) @@ -246,27 +282,26 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder] - (values: List[X1[A]])( - implicit catalystAbsolute: CatalystRound[A, B], encX1: Encoder[X1[A]] - ) = { + def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystAbsolute: CatalystRound[A, B], + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.ceil(cDS("a"))) .map(_.getAs[B](0)) .collect() - .toList.map{ - case bigDecimal : java.math.BigDecimal => bigDecimal.setScale(0) - case other => other - }.asInstanceOf[List[B]] - + .toList + .map { + case bigDecimal: java.math.BigDecimal => bigDecimal.setScale(0) + case other => other + } + .asInstanceOf[List[B]] val typedDS = TypedDataset.create(values) - val res = typedDS - 
.select(ceil(typedDS('a))) - .collect() - .run() - .toList + val res = typedDS.select(ceil(typedDS('a))).collect().run().toList res ?= resCompare } @@ -282,20 +317,22 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(values: List[X1[Array[Byte]]])(implicit encX1: Encoder[X1[Array[Byte]]]) = { + def prop( + values: List[X1[Array[Byte]]] + )(implicit + encX1: Encoder[X1[Array[Byte]]] + ) = { Seq(224, 256, 384, 512).map { numBits => val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.sha2(cDS("a"), numBits)) .map(_.getAs[String](0)) - .collect().toList - - val typedDS = TypedDataset.create(values) - val res = typedDS - .select(sha2(typedDS('a), numBits)) .collect() - .run() .toList + + val typedDS = TypedDataset.create(values) + val res = + typedDS.select(sha2(typedDS('a), numBits)).collect().run().toList res ?= resCompare }.reduce(_ && _) } @@ -307,20 +344,21 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(values: List[X1[Array[Byte]]])(implicit encX1: Encoder[X1[Array[Byte]]]) = { + def prop( + values: List[X1[Array[Byte]]] + )(implicit + encX1: Encoder[X1[Array[Byte]]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.sha1(cDS("a"))) .map(_.getAs[String](0)) - .collect().toList - - val typedDS = TypedDataset.create(values) - val res = typedDS - .select(sha1(typedDS('a))) .collect() - .run() .toList + val typedDS = TypedDataset.create(values) + val res = typedDS.select(sha1(typedDS('a))).collect().run().toList + res ?= resCompare } @@ -331,7 +369,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(values: List[X1[Array[Byte]]])(implicit encX1: Encoder[X1[Array[Byte]]]) = { + def prop( + values: List[X1[Array[Byte]]] + )(implicit + encX1: Encoder[X1[Array[Byte]]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.crc32(cDS("a"))) @@ -340,11 +382,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .toList val typedDS = TypedDataset.create(values) - val res = typedDS - .select(crc32(typedDS('a))) - .collect() - .run() - .toList + val res = typedDS.select(crc32(typedDS('a))).collect().run().toList res ?= resCompare } @@ -356,27 +394,26 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder] - (values: List[X1[A]])( - implicit catalystAbsolute: CatalystRound[A, B], encX1: Encoder[X1[A]] - ) = { + def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystAbsolute: CatalystRound[A, B], + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.floor(cDS("a"))) .map(_.getAs[B](0)) .collect() - .toList.map{ - case bigDecimal : java.math.BigDecimal => bigDecimal.setScale(0) - case other => other - }.asInstanceOf[List[B]] - + .toList + .map { + case bigDecimal: java.math.BigDecimal => bigDecimal.setScale(0) + case other => other + } + .asInstanceOf[List[B]] val typedDS = TypedDataset.create(values) - val res = typedDS - .select(floor(typedDS('a))) - .collect() - .run() - .toList + val res = typedDS.select(floor(typedDS('a))).collect().run().toList res ?= resCompare } @@ -387,35 +424,35 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite 
{ check(forAll(prop[BigDecimal, java.math.BigDecimal] _)) } - test("abs big decimal") { val spark = session import spark.implicits._ - def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder] - (values: List[X1[A]]) - ( - implicit catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, B], - encX1:Encoder[X1[A]] - )= { - val cDS = session.createDataset(values) - val resCompare = cDS - .select(sparkFunctions.abs(cDS("a"))) - .map(_.getAs[B](0)) - .collect().toList + def prop[A: TypedEncoder: Encoder, B: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, B], + encX1: Encoder[X1[A]] + ) = { + val cDS = session.createDataset(values) + val resCompare = cDS + .select(sparkFunctions.abs(cDS("a"))) + .map(_.getAs[B](0)) + .collect() + .toList - val typedDS = TypedDataset.create(values) - val col = typedDS('a) - val res = typedDS - .select( - abs(col) - ) - .collect() - .run() - .toList + val typedDS = TypedDataset.create(values) + val col = typedDS('a) + val res = typedDS + .select( + abs(col) + ) + .collect() + .run() + .toList - res ?= resCompare - } + res ?= resCompare + } check(forAll(prop[BigDecimal, java.math.BigDecimal] _)) } @@ -424,26 +461,22 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder : Encoder] - (values: List[X1[A]]) - ( - implicit catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, A], - encX1: Encoder[X1[A]] - ) = { + def prop[A: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, A], + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.abs(cDS("a"))) .map(_.getAs[A](0)) - .collect().toList - - - val typedDS = TypedDataset.create(values) - val res = typedDS - .select(abs(typedDS('a))) .collect() - .run() .toList + val typedDS = TypedDataset.create(values) + val res = typedDS.select(abs(typedDS('a))).collect().run().toList + res ?= resCompare } @@ -453,36 +486,43 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[Double] _)) } - def propTrigonometric[A: CatalystNumeric: TypedEncoder : Encoder](typedDS: TypedDataset[X1[A]]) - (typedCol: TypedColumn[X1[A], Double], sparkFunc: Column => Column): Prop = { - val spark = session - import spark.implicits._ + def propTrigonometric[A: CatalystNumeric: TypedEncoder: Encoder]( + typedDS: TypedDataset[X1[A]] + )(typedCol: TypedColumn[X1[A], Double], + sparkFunc: Column => Column + ): Prop = { + val spark = session + import spark.implicits._ - val resCompare = typedDS.dataset - .select(sparkFunc($"a")) - .map(_.getAs[Double](0)) - .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList + val resCompare = typedDS.dataset + .select(sparkFunc($"a")) + .map(_.getAs[Double](0)) + .map(DoubleBehaviourUtils.nanNullHandler) + .collect() + .toList - val res = typedDS - .select(typedCol) - .deserialized - .map(DoubleBehaviourUtils.nanNullHandler) - .collect() - .run() - .toList + val res = typedDS + .select(typedCol) + .deserialized + .map(DoubleBehaviourUtils.nanNullHandler) + .collect() + .run() + .toList - res ?= resCompare + res ?= resCompare } test("cos") { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]]) - (implicit encX1:Encoder[X1[A]]) = { - val typedDS = TypedDataset.create(values) - propTrigonometric(typedDS)(cos(typedDS('a)), sparkFunctions.cos) + def 
prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { + val typedDS = TypedDataset.create(values) + propTrigonometric(typedDS)(cos(typedDS('a)), sparkFunctions.cos) } check(forAll(prop[Int] _)) @@ -497,10 +537,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]]) - (implicit encX1:Encoder[X1[A]]) = { - val typedDS = TypedDataset.create(values) - propTrigonometric(typedDS)(cosh(typedDS('a)), sparkFunctions.cosh) + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { + val typedDS = TypedDataset.create(values) + propTrigonometric(typedDS)(cosh(typedDS('a)), sparkFunctions.cosh) } check(forAll(prop[Int] _)) @@ -515,10 +558,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]]) - (implicit encX1:Encoder[X1[A]]) = { - val typedDS = TypedDataset.create(values) - propTrigonometric(typedDS)(acos(typedDS('a)), sparkFunctions.acos) + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { + val typedDS = TypedDataset.create(values) + propTrigonometric(typedDS)(acos(typedDS('a)), sparkFunctions.acos) } check(forAll(prop[Int] _)) @@ -529,16 +575,17 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[Double] _)) } - - test("signum") { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]]) - (implicit encX1:Encoder[X1[A]]) = { - val typedDS = TypedDataset.create(values) - propTrigonometric(typedDS)(signum(typedDS('a)), sparkFunctions.signum) + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { + val typedDS = TypedDataset.create(values) + propTrigonometric(typedDS)(signum(typedDS('a)), sparkFunctions.signum) } check(forAll(prop[Int] _)) @@ -553,10 +600,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]]) - (implicit encX1:Encoder[X1[A]]) = { - val typedDS = TypedDataset.create(values) - propTrigonometric(typedDS)(sin(typedDS('a)), sparkFunctions.sin) + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { + val typedDS = TypedDataset.create(values) + propTrigonometric(typedDS)(sin(typedDS('a)), sparkFunctions.sin) } check(forAll(prop[Int] _)) @@ -571,10 +621,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]]) - (implicit encX1:Encoder[X1[A]]) = { - val typedDS = TypedDataset.create(values) - propTrigonometric(typedDS)(sinh(typedDS('a)), sparkFunctions.sinh) + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { + val typedDS = TypedDataset.create(values) + propTrigonometric(typedDS)(sinh(typedDS('a)), sparkFunctions.sinh) } check(forAll(prop[Int] _)) @@ -589,10 +642,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: 
CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]]) - (implicit encX1:Encoder[X1[A]]) = { - val typedDS = TypedDataset.create(values) - propTrigonometric(typedDS)(asin(typedDS('a)), sparkFunctions.asin) + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { + val typedDS = TypedDataset.create(values) + propTrigonometric(typedDS)(asin(typedDS('a)), sparkFunctions.asin) } check(forAll(prop[Int] _)) @@ -607,10 +663,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]]) - (implicit encX1:Encoder[X1[A]]) = { - val typedDS = TypedDataset.create(values) - propTrigonometric(typedDS)(tan(typedDS('a)), sparkFunctions.tan) + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { + val typedDS = TypedDataset.create(values) + propTrigonometric(typedDS)(tan(typedDS('a)), sparkFunctions.tan) } check(forAll(prop[Int] _)) @@ -625,10 +684,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]]) - (implicit encX1:Encoder[X1[A]]) = { - val typedDS = TypedDataset.create(values) - propTrigonometric(typedDS)(tanh(typedDS('a)), sparkFunctions.tanh) + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { + val typedDS = TypedDataset.create(values) + propTrigonometric(typedDS)(tanh(typedDS('a)), sparkFunctions.tanh) } check(forAll(prop[Int] _)) @@ -639,48 +701,46 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[Double] _)) } - /* - * Currently not all Collection types play nice with the Encoders. - * This test needs to be readressed and Set readded to the Collection Typeclass once these issues are resolved. - * - * [[https://issues.apache.org/jira/browse/SPARK-18891]] - * [[https://issues.apache.org/jira/browse/SPARK-21204]] - */ - test("arrayContains"){ + /* + * Currently not all Collection types play nice with the Encoders. + * This test needs to be readressed and Set readded to the Collection Typeclass once these issues are resolved. 
+ * + * [[https://issues.apache.org/jira/browse/SPARK-18891]] + * [[https://issues.apache.org/jira/browse/SPARK-21204]] + */ + test("arrayContains") { val spark = session import spark.implicits._ val listLength = 10 val idxs = Stream.continually(Range(0, listLength)).flatten.toIterator - abstract class Nth[A, C[A]:CatalystCollection] { + abstract class Nth[A, C[A]: CatalystCollection] { - def nth(c:C[A], idx:Int):A + def nth(c: C[A], idx: Int): A } - implicit def deriveListNth[A] : Nth[A, List] = new Nth[A, List] { + implicit def deriveListNth[A]: Nth[A, List] = new Nth[A, List] { override def nth(c: List[A], idx: Int): A = c(idx) } - implicit def deriveSeqNth[A] : Nth[A, Seq] = new Nth[A, Seq] { + implicit def deriveSeqNth[A]: Nth[A, Seq] = new Nth[A, Seq] { override def nth(c: Seq[A], idx: Int): A = c(idx) } - implicit def deriveVectorNth[A] : Nth[A, Vector] = new Nth[A, Vector] { + implicit def deriveVectorNth[A]: Nth[A, Vector] = new Nth[A, Vector] { override def nth(c: Vector[A], idx: Int): A = c(idx) } - implicit def deriveArrayNth[A] : Nth[A, Array] = new Nth[A, Array] { + implicit def deriveArrayNth[A]: Nth[A, Array] = new Nth[A, Array] { override def nth(c: Array[A], idx: Int): A = c(idx) } - - def prop[C[_] : CatalystCollection] - ( + def prop[C[_]: CatalystCollection]( values: C[Int], - shouldBeIn:Boolean) - ( - implicit nth:Nth[Int, C], + shouldBeIn: Boolean + )(implicit + nth: Nth[Int, C], encEv: Encoder[C[Int]], tEncEv: TypedEncoder[C[Int]] ) = { @@ -691,7 +751,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val resCompare = cDS .select(sparkFunctions.array_contains(cDS("value"), contained)) .map(_.getAs[Boolean](0)) - .collect().toList + .collect() + .toList val typedDS = TypedDataset.create(List(X1(values))) val res = typedDS @@ -705,10 +766,9 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check( forAll( - Gen.listOfN(listLength, Gen.choose(0,100)), - Gen.oneOf(true,false) - ) - (prop[List]) + Gen.listOfN(listLength, Gen.choose(0, 100)), + Gen.oneOf(true, false) + )(prop[List]) ) /*check( Looks like there is no Typed Encoder for Seq type yet @@ -721,18 +781,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check( forAll( - Gen.listOfN(listLength, Gen.choose(0,100)).map(_.toVector), - Gen.oneOf(true,false) - ) - (prop[Vector]) + Gen.listOfN(listLength, Gen.choose(0, 100)).map(_.toVector), + Gen.oneOf(true, false) + )(prop[Vector]) ) check( forAll( - Gen.listOfN(listLength, Gen.choose(0,100)).map(_.toArray), - Gen.oneOf(true,false) - ) - (prop[Array]) + Gen.listOfN(listLength, Gen.choose(0, 100)).map(_.toArray), + Gen.oneOf(true, false) + )(prop[Array]) ) } @@ -740,14 +798,19 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder] - (na: A, values: List[X1[A]])(implicit encX1: Encoder[X1[A]]) = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + na: A, + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(X1(na) :: values) val resCompare = cDS .select(sparkFunctions.atan(cDS("a"))) .map(_.getAs[Double](0)) .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList + .collect() + .toList val typedDS = TypedDataset.create(cDS) val res = typedDS @@ -758,13 +821,17 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .run() .toList - val aggrTyped = typedDS.agg(atan( - frameless.functions.aggregate.first(typedDS('a))) - ).firstOption().run().get + val 
aggrTyped = typedDS + .agg(atan(frameless.functions.aggregate.first(typedDS('a)))) + .firstOption() + .run() + .get - val aggrSpark = cDS.select( - sparkFunctions.atan(sparkFunctions.first("a")).as[Double] - ).first() + val aggrSpark = cDS + .select( + sparkFunctions.atan(sparkFunctions.first("a")).as[Double] + ) + .first() (res ?= resCompare).&&(aggrTyped ?= aggrSpark) } @@ -781,16 +848,21 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder, - B: CatalystNumeric : TypedEncoder : Encoder](na: X2[A, B], values: List[X2[A, B]]) - (implicit encEv: Encoder[X2[A,B]]) = { + def prop[ + A: CatalystNumeric: TypedEncoder: Encoder, + B: CatalystNumeric: TypedEncoder: Encoder + ](na: X2[A, B], + values: List[X2[A, B]] + )(implicit + encEv: Encoder[X2[A, B]] + ) = { val cDS = session.createDataset(na +: values) val resCompare = cDS .select(sparkFunctions.atan2(cDS("a"), cDS("b"))) .map(_.getAs[Double](0)) .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList - + .collect() + .toList val typedDS = TypedDataset.create(cDS) val res = typedDS @@ -801,19 +873,28 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .run() .toList - val aggrTyped = typedDS.agg(atan2( - frameless.functions.aggregate.first(typedDS('a)), - frameless.functions.aggregate.first(typedDS('b))) - ).firstOption().run().get + val aggrTyped = typedDS + .agg( + atan2( + frameless.functions.aggregate.first(typedDS('a)), + frameless.functions.aggregate.first(typedDS('b)) + ) + ) + .firstOption() + .run() + .get - val aggrSpark = cDS.select( - sparkFunctions.atan2(sparkFunctions.first("a"),sparkFunctions.first("b")).as[Double] - ).first() + val aggrSpark = cDS + .select( + sparkFunctions + .atan2(sparkFunctions.first("a"), sparkFunctions.first("b")) + .as[Double] + ) + .first() (res ?= resCompare).&&(aggrTyped ?= aggrSpark) } - check(forAll(prop[Int, Long] _)) check(forAll(prop[Long, Int] _)) check(forAll(prop[Short, Byte] _)) @@ -826,15 +907,20 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder] - (na: X1[A], value: List[X1[A]], lit:Double)(implicit encX1:Encoder[X1[A]]) = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + na: X1[A], + value: List[X1[A]], + lit: Double + )(implicit + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(na +: value) val resCompare = cDS .select(sparkFunctions.atan2(lit, cDS("a"))) .map(_.getAs[Double](0)) .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList - + .collect() + .toList val typedDS = TypedDataset.create(cDS) val res = typedDS @@ -845,14 +931,17 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .run() .toList - val aggrTyped = typedDS.agg(atan2( - lit, - frameless.functions.aggregate.first(typedDS('a))) - ).firstOption().run().get + val aggrTyped = typedDS + .agg(atan2(lit, frameless.functions.aggregate.first(typedDS('a)))) + .firstOption() + .run() + .get - val aggrSpark = cDS.select( - sparkFunctions.atan2(lit, sparkFunctions.first("a")).as[Double] - ).first() + val aggrSpark = cDS + .select( + sparkFunctions.atan2(lit, sparkFunctions.first("a")).as[Double] + ) + .first() (res ?= resCompare).&&(aggrTyped ?= aggrSpark) } @@ -869,15 +958,20 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder] - (na: X1[A], 
value: List[X1[A]], lit:Double)(implicit encX1:Encoder[X1[A]]) = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + na: X1[A], + value: List[X1[A]], + lit: Double + )(implicit + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(na +: value) val resCompare = cDS .select(sparkFunctions.atan2(cDS("a"), lit)) .map(_.getAs[Double](0)) .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList - + .collect() + .toList val typedDS = TypedDataset.create(cDS) val res = typedDS @@ -888,19 +982,21 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .run() .toList - val aggrTyped = typedDS.agg(atan2( - frameless.functions.aggregate.first(typedDS('a)), - lit) - ).firstOption().run().get + val aggrTyped = typedDS + .agg(atan2(frameless.functions.aggregate.first(typedDS('a)), lit)) + .firstOption() + .run() + .get - val aggrSpark = cDS.select( - sparkFunctions.atan2(sparkFunctions.first("a"), lit).as[Double] - ).first() + val aggrSpark = cDS + .select( + sparkFunctions.atan2(sparkFunctions.first("a"), lit).as[Double] + ) + .first() (res ?= resCompare).&&(aggrTyped ?= aggrSpark) } - check(forAll(prop[Int] _)) check(forAll(prop[Long] _)) check(forAll(prop[Short] _)) @@ -909,9 +1005,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[Double] _)) } - def mathProp[A: CatalystNumeric: TypedEncoder : Encoder](typedDS: TypedDataset[X1[A]])( - typedCol: TypedColumn[X1[A], Double], sparkFunc: Column => Column - ): Prop = { + def mathProp[A: CatalystNumeric: TypedEncoder: Encoder]( + typedDS: TypedDataset[X1[A]] + )(typedCol: TypedColumn[X1[A], Double], + sparkFunc: Column => Column + ): Prop = { val spark = session import spark.implicits._ @@ -919,7 +1017,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .select(sparkFunc($"a")) .map(_.getAs[Double](0)) .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList + .collect() + .toList val res = typedDS .select(typedCol) @@ -936,7 +1035,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { val typedDS = TypedDataset.create(values) mathProp(typedDS)(sqrt(typedDS('a)), sparkFunctions.sqrt) } @@ -953,7 +1056,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { val typedDS = TypedDataset.create(values) mathProp(typedDS)(cbrt(typedDS('a)), sparkFunctions.cbrt) } @@ -970,7 +1077,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { val typedDS = TypedDataset.create(values) mathProp(typedDS)(exp(typedDS('a)), sparkFunctions.exp) } @@ -987,7 +1098,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder : Encoder](values: List[X1[A]]): Prop = { + 
def prop[A: TypedEncoder: Encoder](values: List[X1[A]]): Prop = { val spark = session import spark.implicits._ @@ -996,14 +1107,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val resCompare = typedDS.dataset .select(sparkFunctions.md5($"a")) .map(_.getAs[String](0)) - .collect().toList - - val res = typedDS - .select(md5(typedDS('a))) .collect() - .run() .toList + val res = typedDS.select(md5(typedDS('a))).collect().run().toList + res ?= resCompare } @@ -1022,14 +1130,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val resCompare = typedDS.dataset .select(sparkFunctions.factorial($"a")) .map(_.getAs[Long](0)) - .collect().toList - - val res = typedDS - .select(factorial(typedDS('a))) .collect() - .run() .toList + val res = typedDS.select(factorial(typedDS('a))).collect().run().toList + res ?= resCompare } @@ -1040,24 +1145,25 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder : Encoder](values: List[X1[A]])( - implicit catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[A, A], - encX1: Encoder[X1[A]] - ) = { + def prop[A: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[ + A, + A + ], + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.round(cDS("a"))) .map(_.getAs[A](0)) - .collect().toList - - - val typedDS = TypedDataset.create(values) - val res = typedDS - .select(round(typedDS('a))) .collect() - .run() .toList + val typedDS = TypedDataset.create(values) + val res = typedDS.select(round(typedDS('a))).collect().run().toList + res ?= resCompare } @@ -1071,25 +1177,27 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder: Encoder](values: List[X1[A]])( - implicit catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, java.math.BigDecimal], - encX1:Encoder[X1[A]] - ) = { + def prop[A: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystAbsolute: CatalystNumericWithJavaBigDecimal[ + A, + java.math.BigDecimal + ], + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.round(cDS("a"))) .map(_.getAs[java.math.BigDecimal](0)) .collect() - .toList.map(_.setScale(0)) + .toList + .map(_.setScale(0)) val typedDS = TypedDataset.create(values) val col = typedDS('a) - val res = typedDS - .select(round(col)) - .collect() - .run() - .toList + val res = typedDS.select(round(col)).collect().run().toList res ?= resCompare } @@ -1101,24 +1209,25 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder : Encoder](values: List[X1[A]])( - implicit catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[A, A], - encX1: Encoder[X1[A]] - ) = { + def prop[A: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[ + A, + A + ], + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.round(cDS("a"), 1)) .map(_.getAs[A](0)) - .collect().toList - - - val typedDS = TypedDataset.create(values) - val res = typedDS - .select(round(typedDS('a), 1)) .collect() - .run() .toList + val typedDS = TypedDataset.create(values) + val res = 
typedDS.select(round(typedDS('a), 1)).collect().run().toList + res ?= resCompare } @@ -1132,25 +1241,27 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder: Encoder](values: List[X1[A]])( - implicit catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, java.math.BigDecimal], - encX1:Encoder[X1[A]] - ) = { + def prop[A: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystAbsolute: CatalystNumericWithJavaBigDecimal[ + A, + java.math.BigDecimal + ], + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.round(cDS("a"), 0)) .map(_.getAs[java.math.BigDecimal](0)) .collect() - .toList.map(_.setScale(0)) + .toList + .map(_.setScale(0)) val typedDS = TypedDataset.create(values) val col = typedDS('a) - val res = typedDS - .select(round(col, 0)) - .collect() - .run() - .toList + val res = typedDS.select(round(col, 0)).collect().run().toList res ?= resCompare } @@ -1162,24 +1273,25 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: TypedEncoder : Encoder](values: List[X1[A]])( - implicit catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[A, A], - encX1: Encoder[X1[A]] - ) = { + def prop[A: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[ + A, + A + ], + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.bround(cDS("a"))) .map(_.getAs[A](0)) - .collect().toList - - - val typedDS = TypedDataset.create(values) - val res = typedDS - .select(bround(typedDS('a))) .collect() - .run() .toList + val typedDS = TypedDataset.create(values) + val res = typedDS.select(bround(typedDS('a))).collect().run().toList + res ?= resCompare } @@ -1187,31 +1299,33 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[Long] _)) check(forAll(prop[Short] _)) check(forAll(prop[Double] _)) - } + } test("bround big decimal") { val spark = session import spark.implicits._ - def prop[A: TypedEncoder: Encoder](values: List[X1[A]])( - implicit catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, java.math.BigDecimal], - encX1:Encoder[X1[A]] - ) = { + def prop[A: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystAbsolute: CatalystNumericWithJavaBigDecimal[ + A, + java.math.BigDecimal + ], + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.bround(cDS("a"))) .map(_.getAs[java.math.BigDecimal](0)) .collect() - .toList.map(_.setScale(0)) + .toList + .map(_.setScale(0)) val typedDS = TypedDataset.create(values) val col = typedDS('a) - val res = typedDS - .select(bround(col)) - .collect() - .run() - .toList + val res = typedDS.select(bround(col)).collect().run().toList res ?= resCompare } @@ -1219,63 +1333,66 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[BigDecimal] _)) } - test("bround with scale") { - val spark = session - import spark.implicits._ + test("bround with scale") { + val spark = session + import spark.implicits._ - def prop[A: TypedEncoder : Encoder](values: List[X1[A]])( - implicit catalystNumericWithJavaBigDecimal: CatalystNumericWithJavaBigDecimal[A, A], + def prop[A: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystNumericWithJavaBigDecimal: 
CatalystNumericWithJavaBigDecimal[ + A, + A + ], encX1: Encoder[X1[A]] ) = { - val cDS = session.createDataset(values) - val resCompare = cDS - .select(sparkFunctions.bround(cDS("a"), 1)) - .map(_.getAs[A](0)) - .collect().toList - - - val typedDS = TypedDataset.create(values) - val res = typedDS - .select(bround(typedDS('a), 1)) - .collect() - .run() - .toList + val cDS = session.createDataset(values) + val resCompare = cDS + .select(sparkFunctions.bround(cDS("a"), 1)) + .map(_.getAs[A](0)) + .collect() + .toList - res ?= resCompare - } + val typedDS = TypedDataset.create(values) + val res = typedDS.select(bround(typedDS('a), 1)).collect().run().toList - check(forAll(prop[Int] _)) - check(forAll(prop[Long] _)) - check(forAll(prop[Short] _)) - check(forAll(prop[Double] _)) + res ?= resCompare } - test("bround big decimal with scale") { - val spark = session - import spark.implicits._ + check(forAll(prop[Int] _)) + check(forAll(prop[Long] _)) + check(forAll(prop[Short] _)) + check(forAll(prop[Double] _)) + } - def prop[A: TypedEncoder: Encoder](values: List[X1[A]])( - implicit catalystAbsolute: CatalystNumericWithJavaBigDecimal[A, java.math.BigDecimal], - encX1:Encoder[X1[A]] + test("bround big decimal with scale") { + val spark = session + import spark.implicits._ + + def prop[A: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + catalystAbsolute: CatalystNumericWithJavaBigDecimal[ + A, + java.math.BigDecimal + ], + encX1: Encoder[X1[A]] ) = { - val cDS = session.createDataset(values) - - val resCompare = cDS - .select(sparkFunctions.bround(cDS("a"), 0)) - .map(_.getAs[java.math.BigDecimal](0)) - .collect() - .toList.map(_.setScale(0)) - - val typedDS = TypedDataset.create(values) - val col = typedDS('a) - val res = typedDS - .select(bround(col, 0)) - .collect() - .run() - .toList - - res ?= resCompare - } + val cDS = session.createDataset(values) + + val resCompare = cDS + .select(sparkFunctions.bround(cDS("a"), 0)) + .map(_.getAs[java.math.BigDecimal](0)) + .collect() + .toList + .map(_.setScale(0)) + + val typedDS = TypedDataset.create(values) + val col = typedDS('a) + val res = typedDS.select(bround(col, 0)).collect().run().toList + + res ?= resCompare + } check(forAll(prop[BigDecimal] _)) } @@ -1285,10 +1402,10 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { import spark.implicits._ import NonNegativeArbitraryNumericValues._ - def prop[A: CatalystNumeric: TypedEncoder : Encoder]( - values: List[X1[A]], - base: Double - ): Prop = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]], + base: Double + ): Prop = { val spark = session import spark.implicits._ val typedDS = TypedDataset.create(values) @@ -1297,7 +1414,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .select(sparkFunctions.log(base, $"a")) .map(_.getAs[Double](0)) .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList + .collect() + .toList val res = typedDS .select(log(base, typedDS('a))) @@ -1322,7 +1440,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { import spark.implicits._ import NonNegativeArbitraryNumericValues._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { val typedDS = TypedDataset.create(values) mathProp(typedDS)(log(typedDS('a)), sparkFunctions.log) } @@ -1339,7 +1461,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { 
import spark.implicits._ import NonNegativeArbitraryNumericValues._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { val typedDS = TypedDataset.create(values) mathProp(typedDS)(log2(typedDS('a)), sparkFunctions.log2) } @@ -1356,7 +1482,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { import spark.implicits._ import NonNegativeArbitraryNumericValues._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { val typedDS = TypedDataset.create(values) mathProp(typedDS)(log1p(typedDS('a)), sparkFunctions.log1p) } @@ -1373,7 +1503,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { import spark.implicits._ import NonNegativeArbitraryNumericValues._ - def prop[A: CatalystNumeric : TypedEncoder : Encoder](values: List[X1[A]])(implicit encX1:Encoder[X1[A]]) = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { val typedDS = TypedDataset.create(values) mathProp(typedDS)(log10(typedDS('a)), sparkFunctions.log10) } @@ -1389,20 +1523,21 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(values:List[X1[Array[Byte]]])(implicit encX1:Encoder[X1[Array[Byte]]]) = { + def prop( + values: List[X1[Array[Byte]]] + )(implicit + encX1: Encoder[X1[Array[Byte]]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.base64(cDS("a"))) .map(_.getAs[String](0)) - .collect().toList - - val typedDS = TypedDataset.create(values) - val res = typedDS - .select(base64(typedDS('a))) .collect() - .run() .toList + val typedDS = TypedDataset.create(values) + val res = typedDS.select(base64(typedDS('a))).collect().run().toList + val backAndForth = typedDS .select(base64(unbase64(base64(typedDS('a))))) .collect() @@ -1419,10 +1554,10 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric: TypedEncoder : Encoder]( - values: List[X1[A]], - base: Double - ): Prop = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]], + base: Double + ): Prop = { val spark = session import spark.implicits._ val typedDS = TypedDataset.create(values) @@ -1431,7 +1566,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .select(sparkFunctions.hypot(base, $"a")) .map(_.getAs[Double](0)) .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList + .collect() + .toList val res2 = typedDS .select(hypot(typedDS('a), base)) @@ -1463,9 +1599,9 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric: TypedEncoder : Encoder]( - values: List[X2[A, A]] - ): Prop = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X2[A, A]] + ): Prop = { val spark = session import spark.implicits._ val typedDS = TypedDataset.create(values) @@ -1474,7 +1610,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .select(sparkFunctions.hypot($"b", $"a")) .map(_.getAs[Double](0)) .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList + .collect() + .toList val res = typedDS 
.select(hypot(typedDS('b), typedDS('a))) @@ -1498,10 +1635,10 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric: TypedEncoder : Encoder]( - values: List[X1[A]], - base: Double - ): Prop = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X1[A]], + base: Double + ): Prop = { val spark = session import spark.implicits._ val typedDS = TypedDataset.create(values) @@ -1510,7 +1647,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .select(sparkFunctions.pow(base, $"a")) .map(_.getAs[Double](0)) .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList + .collect() + .toList val res = typedDS .select(pow(base, typedDS('a))) @@ -1524,7 +1662,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .select(sparkFunctions.pow($"a", base)) .map(_.getAs[Double](0)) .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList + .collect() + .toList val res2 = typedDS .select(pow(typedDS('a), base)) @@ -1534,7 +1673,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .run() .toList - (res ?= resCompare) && (res2 ?= resCompare2) + (res ?= resCompare) && (res2 ?= resCompare2) } check(forAll(prop[Int] _)) @@ -1548,9 +1687,9 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A: CatalystNumeric: TypedEncoder : Encoder]( - values: List[X2[A, A]] - ): Prop = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X2[A, A]] + ): Prop = { val spark = session import spark.implicits._ val typedDS = TypedDataset.create(values) @@ -1559,7 +1698,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .select(sparkFunctions.pow($"b", $"a")) .map(_.getAs[Double](0)) .map(DoubleBehaviourUtils.nanNullHandler) - .collect().toList + .collect() + .toList val res = typedDS .select(pow(typedDS('b), typedDS('a))) @@ -1584,9 +1724,9 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { import spark.implicits._ import NonNegativeArbitraryNumericValues._ - def prop[A: CatalystNumeric: TypedEncoder : Encoder]( - values: List[X2[A, A]] - ): Prop = { + def prop[A: CatalystNumeric: TypedEncoder: Encoder]( + values: List[X2[A, A]] + ): Prop = { val spark = session import spark.implicits._ val typedDS = TypedDataset.create(values) @@ -1594,14 +1734,12 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val resCompare = typedDS.dataset .select(sparkFunctions.pmod($"b", $"a")) .map(_.getAs[A](0)) - .collect().toList - - val res = typedDS - .select(pmod(typedDS('b), typedDS('a))) .collect() - .run() .toList + val res = + typedDS.select(pmod(typedDS('b), typedDS('a))).collect().run().toList + res ?= resCompare } @@ -1616,71 +1754,73 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(values: List[X1[String]])(implicit encX1: Encoder[X1[String]]) = { + def prop( + values: List[X1[String]] + )(implicit + encX1: Encoder[X1[String]] + ) = { val valuesBase64 = values.map(base64X1String) val cDS = session.createDataset(valuesBase64) val resCompare = cDS .select(sparkFunctions.unbase64(cDS("a"))) .map(_.getAs[Array[Byte]](0)) - .collect().toList - - val typedDS = TypedDataset.create(valuesBase64) - val res = typedDS - .select(unbase64(typedDS('a))) .collect() - .run() .toList + val typedDS = TypedDataset.create(valuesBase64) + val res = typedDS.select(unbase64(typedDS('a))).collect().run().toList + 
res.map(_.toList) ?= resCompare.map(_.toList) } check(forAll(prop _)) } - test("bin"){ + test("bin") { val spark = session import spark.implicits._ - def prop(values:List[X1[Long]])(implicit encX1:Encoder[X1[Long]]) = { + def prop( + values: List[X1[Long]] + )(implicit + encX1: Encoder[X1[Long]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.bin(cDS("a"))) .map(_.getAs[String](0)) - .collect().toList - - val typedDS = TypedDataset.create(values) - val res = typedDS - .select(bin(typedDS('a))) .collect() - .run() .toList + val typedDS = TypedDataset.create(values) + val res = typedDS.select(bin(typedDS('a))).collect().run().toList + res ?= resCompare } check(forAll(prop _)) } - test("bitwiseNOT"){ + test("bitwiseNOT") { val spark = session import spark.implicits._ @nowarn // supress sparkFunctions.bitwiseNOT call which is used to maintain Spark 3.1.x backwards compat - def prop[A: CatalystBitwise : TypedEncoder : Encoder] - (values:List[X1[A]])(implicit encX1:Encoder[X1[A]]) = { + def prop[A: CatalystBitwise: TypedEncoder: Encoder]( + values: List[X1[A]] + )(implicit + encX1: Encoder[X1[A]] + ) = { val cDS = session.createDataset(values) val resCompare = cDS .select(sparkFunctions.bitwiseNOT(cDS("a"))) .map(_.getAs[A](0)) - .collect().toList - - val typedDS = TypedDataset.create(values) - val res = typedDS - .select(bitwiseNOT(typedDS('a))) .collect() - .run() .toList + val typedDS = TypedDataset.create(values) + val res = typedDS.select(bitwiseNOT(typedDS('a))).collect().run().toList + res ?= resCompare } @@ -1694,11 +1834,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A : TypedEncoder]( - toFile1: List[X1[A]], - toFile2: List[X1[A]], - inMem: List[X1[A]] - )(implicit x2Gen: Encoder[X2[A, String]], x3Gen: Encoder[X3[A, String, String]]) = { + def prop[A: TypedEncoder]( + toFile1: List[X1[A]], + toFile2: List[X1[A]], + inMem: List[X1[A]] + )(implicit + x2Gen: Encoder[X2[A, String]], + x3Gen: Encoder[X3[A, String, String]] + ) = { val file1Path = testTempFiles + "/file1" val file2Path = testTempFiles + "/file2" @@ -1719,7 +1862,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val unioned = ds1.union(ds2).union(ds3) - val withFileName = unioned.withColumn[X3[A, String, String]](inputFileName[X2[A, String]]()) + val withFileName = unioned + .withColumn[X3[A, String, String]](inputFileName[X2[A, String]]()) .collect() .run() .toVector @@ -1727,10 +1871,13 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val grouped = withFileName.groupBy(_.b).mapValues(_.map(_.c).toSet) grouped.foldLeft(passed) { (p, g) => - p && secure { g._1 match { - case "" => g._2.head == "" //Empty string if didn't come from file - case f => g._2.forall(_.contains(f)) - }}} + p && secure { + g._1 match { + case "" => g._2.head == "" // Empty string if didn't come from file + case f => g._2.forall(_.contains(f)) + } + } + } } check(forAll(prop[String] _)) @@ -1740,17 +1887,22 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A : TypedEncoder](xs: List[X1[A]])(implicit x2en: Encoder[X2[A, Long]]) = { + def prop[A: TypedEncoder]( + xs: List[X1[A]] + )(implicit + x2en: Encoder[X2[A, Long]] + ) = { val ds = TypedDataset.create(xs) - val result = ds.withColumn[X2[A, Long]](monotonicallyIncreasingId()) + val result = ds + .withColumn[X2[A, Long]](monotonicallyIncreasingId()) .collect() .run() .toVector val ids 
= result.map(_.b) (ids.toSet.size ?= ids.length) && - (ids.sorted ?= ids) + (ids.sorted ?= ids) } check(forAll(prop[String] _)) @@ -1760,13 +1912,22 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop[A : TypedEncoder : Encoder] - (condition1: Boolean, condition2: Boolean, value1: A, value2: A, otherwise: A) = { - val ds = TypedDataset.create(X5(condition1, condition2, value1, value2, otherwise) :: Nil) + def prop[A: TypedEncoder: Encoder]( + condition1: Boolean, + condition2: Boolean, + value1: A, + value2: A, + otherwise: A + ) = { + val ds = TypedDataset.create( + X5(condition1, condition2, value1, value2, otherwise) :: Nil + ) - val untypedWhen = ds.toDF() + val untypedWhen = ds + .toDF() .select( - sparkFunctions.when(sparkFunctions.col("a"), sparkFunctions.col("c")) + sparkFunctions + .when(sparkFunctions.col("a"), sparkFunctions.col("c")) .when(sparkFunctions.col("b"), sparkFunctions.col("d")) .otherwise(sparkFunctions.col("e")) ) @@ -1776,9 +1937,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val typedWhen = ds .select( - when(ds('a), ds('c)) - .when(ds('b), ds('d)) - .otherwise(ds('e)) + when(ds('a), ds('c)).when(ds('b), ds('d)).otherwise(ds('e)) ) .collect() .run() @@ -1800,17 +1959,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll { values: List[X1[String]] => val ds = TypedDataset.create(values) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.ascii($"a")) .map(_.getAs[Int](0)) .collect() .toVector - val typed = ds - .select(ascii(ds('a))) - .collect() - .run() - .toVector + val typed = ds.select(ascii(ds('a))).collect().run().toVector typed ?= sparkResult }) @@ -1828,19 +1984,18 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(pairs) { values: List[X2[String, String]] => val ds = TypedDataset.create(values) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.concat($"a", $"b")) .map(_.getAs[String](0)) .collect() .toVector - val typed = ds - .select(concat(ds('a), ds('b))) - .collect() - .run() - .toVector + val typed = ds.select(concat(ds('a), ds('b))).collect().run().toVector - (typed ?= sparkResult).&&(typed ?= values.map(x => s"${x.a}${x.b}").toVector) + (typed ?= sparkResult).&&( + typed ?= values.map(x => s"${x.a}${x.b}").toVector + ) }) } @@ -1855,10 +2010,18 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(pairs) { values: List[X2[String, String]] => val ds = TypedDataset.create(values) - val td = ds.agg(concat(first(ds('a)),first(ds('b)))).collect().run().toVector - val spark = ds.dataset.select(sparkFunctions.concat( - sparkFunctions.first($"a").as[String], - sparkFunctions.first($"b").as[String])).as[String].collect().toVector + val td = + ds.agg(concat(first(ds('a)), first(ds('b)))).collect().run().toVector + val spark = ds.dataset + .select( + sparkFunctions.concat( + sparkFunctions.first($"a").as[String], + sparkFunctions.first($"b").as[String] + ) + ) + .as[String] + .collect() + .toVector td ?= spark }) } @@ -1875,17 +2038,15 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(pairs) { values: List[X2[String, String]] => val ds = TypedDataset.create(values) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.concat_ws(",", $"a", $"b")) .map(_.getAs[String](0)) .collect() .toVector - val typed = ds - .select(concatWs(",", ds('a), ds('b))) - .collect() - .run() - 
.toVector + val typed = + ds.select(concatWs(",", ds('a), ds('b))).collect().run().toVector typed ?= sparkResult }) @@ -1902,11 +2063,23 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(pairs) { values: List[X2[String, String]] => val ds = TypedDataset.create(values) - val td = ds.agg(concatWs(",",first(ds('a)),first(ds('b)), last(ds('b)))).collect().run().toVector - val spark = ds.dataset.select(sparkFunctions.concat_ws(",", - sparkFunctions.first($"a").as[String], - sparkFunctions.first($"b").as[String], - sparkFunctions.last($"b").as[String])).as[String].collect().toVector + val td = ds + .agg(concatWs(",", first(ds('a)), first(ds('b)), last(ds('b)))) + .collect() + .run() + .toVector + val spark = ds.dataset + .select( + sparkFunctions.concat_ws( + ",", + sparkFunctions.first($"a").as[String], + sparkFunctions.first($"b").as[String], + sparkFunctions.last($"b").as[String] + ) + ) + .as[String] + .collect() + .toVector td ?= spark }) } @@ -1917,17 +2090,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(Gen.nonEmptyListOf(Gen.alphaStr)) { values: List[String] => val ds = TypedDataset.create(values.map(x => X1(x + values.head))) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.instr($"a", values.head)) .map(_.getAs[Int](0)) .collect() .toVector - val typed = ds - .select(instr(ds('a), values.head)) - .collect() - .run() - .toVector + val typed = ds.select(instr(ds('a), values.head)).collect().run().toVector typed ?= sparkResult }) @@ -1939,17 +2109,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll { values: List[X1[String]] => val ds = TypedDataset.create(values) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.length($"a")) .map(_.getAs[Int](0)) .collect() .toVector - val typed = ds - .select(length(ds[String]('a))) - .collect() - .run() - .toVector + val typed = ds.select(length(ds[String]('a))).collect().run().toVector (typed ?= sparkResult).&&(values.map(_.a.length).toVector ?= typed) }) @@ -1961,26 +2128,43 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll { (na: X1[String], values: List[X1[String]]) => val ds = TypedDataset.create(na +: values) - val sparkResult = ds.toDF() - .select(sparkFunctions.levenshtein($"a", sparkFunctions.concat($"a",sparkFunctions.lit("Hello")))) + val sparkResult = ds + .toDF() + .select( + sparkFunctions.levenshtein( + $"a", + sparkFunctions.concat($"a", sparkFunctions.lit("Hello")) + ) + ) .map(_.getAs[Int](0)) .collect() .toVector val typed = ds - .select(levenshtein(ds('a), concat(ds('a),lit("Hello")))) + .select(levenshtein(ds('a), concat(ds('a), lit("Hello")))) .collect() .run() .toVector val cDS = ds.dataset - val aggrTyped = ds.agg( - levenshtein(frameless.functions.aggregate.first(ds('a)), litAggr("Hello")) - ).firstOption().run().get + val aggrTyped = ds + .agg( + levenshtein( + frameless.functions.aggregate.first(ds('a)), + litAggr("Hello") + ) + ) + .firstOption() + .run() + .get - val aggrSpark = cDS.select( - sparkFunctions.levenshtein(sparkFunctions.first("a"), sparkFunctions.lit("Hello")).as[Int] - ).first() + val aggrSpark = cDS + .select( + sparkFunctions + .levenshtein(sparkFunctions.first("a"), sparkFunctions.lit("Hello")) + .as[Int] + ) + .first() (typed ?= sparkResult).&&(aggrTyped ?= aggrSpark) }) @@ -1992,7 +2176,8 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll { (values: List[X1[String]], n: Int) => val ds = 
TypedDataset.create(values.map(x => X1(s"$n${x.a}-$n$n"))) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.regexp_replace($"a", "\\d+", "n")) .map(_.getAs[String](0)) .collect() @@ -2014,17 +2199,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll { values: List[X1[String]] => val ds = TypedDataset.create(values) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.reverse($"a")) .map(_.getAs[String](0)) .collect() .toVector - val typed = ds - .select(reverse(ds[String]('a))) - .collect() - .run() - .toVector + val typed = ds.select(reverse(ds[String]('a))).collect().run().toVector (typed ?= sparkResult).&&(values.map(_.a.reverse).toVector ?= typed) }) @@ -2036,17 +2218,15 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll { values: List[X1[String]] => val ds = TypedDataset.create(values) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.rpad($"a", 5, "hello")) .map(_.getAs[String](0)) .collect() .toVector - val typed = ds - .select(rpad(ds[String]('a), 5, "hello")) - .collect() - .run() - .toVector + val typed = + ds.select(rpad(ds[String]('a), 5, "hello")).collect().run().toVector typed ?= sparkResult }) @@ -2058,17 +2238,15 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll { values: List[X1[String]] => val ds = TypedDataset.create(values) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.lpad($"a", 5, "hello")) .map(_.getAs[String](0)) .collect() .toVector - val typed = ds - .select(lpad(ds[String]('a), 5, "hello")) - .collect() - .run() - .toVector + val typed = + ds.select(lpad(ds[String]('a), 5, "hello")).collect().run().toVector typed ?= sparkResult }) @@ -2080,17 +2258,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll { values: List[X1[String]] => val ds = TypedDataset.create(values.map(x => X1(s" ${x.a} "))) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.rtrim($"a")) .map(_.getAs[String](0)) .collect() .toVector - val typed = ds - .select(rtrim(ds[String]('a))) - .collect() - .run() - .toVector + val typed = ds.select(rtrim(ds[String]('a))).collect().run().toVector typed ?= sparkResult }) @@ -2102,17 +2277,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll { values: List[X1[String]] => val ds = TypedDataset.create(values.map(x => X1(s" ${x.a} "))) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.ltrim($"a")) .map(_.getAs[String](0)) .collect() .toVector - val typed = ds - .select(ltrim(ds[String]('a))) - .collect() - .run() - .toVector + val typed = ds.select(ltrim(ds[String]('a))).collect().run().toVector typed ?= sparkResult }) @@ -2124,17 +2296,15 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll { values: List[X1[String]] => val ds = TypedDataset.create(values) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.substring($"a", 5, 3)) .map(_.getAs[String](0)) .collect() .toVector - val typed = ds - .select(substring(ds[String]('a), 5, 3)) - .collect() - .run() - .toVector + val typed = + ds.select(substring(ds[String]('a), 5, 3)).collect().run().toVector typed ?= sparkResult }) @@ -2146,17 +2316,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll { values: List[X1[String]] => val ds = TypedDataset.create(values.map(x => X1(s" ${x.a} "))) 
- val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.trim($"a")) .map(_.getAs[String](0)) .collect() .toVector - val typed = ds - .select(trim(ds[String]('a))) - .collect() - .run() - .toVector + val typed = ds.select(trim(ds[String]('a))).collect().run().toVector typed ?= sparkResult }) @@ -2168,17 +2335,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(Gen.listOf(Gen.alphaStr)) { values: List[String] => val ds = TypedDataset.create(values.map(X1(_))) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.upper($"a")) .map(_.getAs[String](0)) .collect() .toVector - val typed = ds - .select(upper(ds[String]('a))) - .collect() - .run() - .toVector + val typed = ds.select(upper(ds[String]('a))).collect().run().toVector typed ?= sparkResult }) @@ -2190,27 +2354,29 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(Gen.listOf(Gen.alphaStr)) { values: List[String] => val ds = TypedDataset.create(values.map(X1(_))) - val sparkResult = ds.toDF() + val sparkResult = ds + .toDF() .select(sparkFunctions.lower($"a")) .map(_.getAs[String](0)) .collect() .toVector - val typed = ds - .select(lower(ds[String]('a))) - .collect() - .run() - .toVector + val typed = ds.select(lower(ds[String]('a))).collect().run().toVector typed ?= sparkResult }) } test("Empty vararg tests") { - def prop[A : TypedEncoder, B: TypedEncoder](data: Vector[X2[A, B]]) = { + def prop[A: TypedEncoder, B: TypedEncoder](data: Vector[X2[A, B]]) = { val ds = TypedDataset.create(data) - val frameless = ds.select(ds('a), concat(), ds('b), concatWs(":")).collect().run().toVector - val framelessAggr = ds.agg(concat(), concatWs("x"), litAggr(2)).collect().run().toVector + val frameless = ds + .select(ds('a), concat(), ds('b), concatWs(":")) + .collect() + .run() + .toVector + val framelessAggr = + ds.agg(concat(), concatWs("x"), litAggr(2)).collect().run().toVector val scala = data.map(x => (x.a, "", x.b, "")) val scalaAggr = Vector(("", "", 2)) (frameless ?= scala).&&(framelessAggr ?= scalaAggr) @@ -2220,8 +2386,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { check(forAll(prop[Option[Boolean], Long] _)) } - def dateTimeStringProp(typedDS: TypedDataset[X1[String]]) - (typedCol: TypedColumn[X1[String], Option[Int]], sparkFunc: Column => Column): Prop = { + def dateTimeStringProp( + typedDS: TypedDataset[X1[String]] + )(typedCol: TypedColumn[X1[String], Option[Int]], + sparkFunc: Column => Column + ): Prop = { val spark = session import spark.implicits._ @@ -2231,11 +2400,7 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { .collect() .toList - val typed = typedDS - .select(typedCol) - .collect() - .run() - .toList + val typed = typedDS.select(typedCol).collect().run().toList typed ?= sparkResult } @@ -2244,10 +2409,14 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = { - val ds = TypedDataset.create(data) - dateTimeStringProp(ds)(year(ds[String]('a)), sparkFunctions.year) - } + def prop( + data: List[X1[String]] + )(implicit + E: Encoder[Option[Int]] + ): Prop = { + val ds = TypedDataset.create(data) + dateTimeStringProp(ds)(year(ds[String]('a)), sparkFunctions.year) + } check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply)))) check(forAll(prop _)) @@ -2257,7 +2426,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session 
import spark.implicits._ - def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = { + def prop( + data: List[X1[String]] + )(implicit + E: Encoder[Option[Int]] + ): Prop = { val ds = TypedDataset.create(data) dateTimeStringProp(ds)(quarter(ds[String]('a)), sparkFunctions.quarter) } @@ -2270,7 +2443,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = { + def prop( + data: List[X1[String]] + )(implicit + E: Encoder[Option[Int]] + ): Prop = { val ds = TypedDataset.create(data) dateTimeStringProp(ds)(month(ds[String]('a)), sparkFunctions.month) } @@ -2283,9 +2460,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = { + def prop( + data: List[X1[String]] + )(implicit + E: Encoder[Option[Int]] + ): Prop = { val ds = TypedDataset.create(data) - dateTimeStringProp(ds)(dayofweek(ds[String]('a)), sparkFunctions.dayofweek) + dateTimeStringProp(ds)( + dayofweek(ds[String]('a)), + sparkFunctions.dayofweek + ) } check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply)))) @@ -2296,9 +2480,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = { + def prop( + data: List[X1[String]] + )(implicit + E: Encoder[Option[Int]] + ): Prop = { val ds = TypedDataset.create(data) - dateTimeStringProp(ds)(dayofmonth(ds[String]('a)), sparkFunctions.dayofmonth) + dateTimeStringProp(ds)( + dayofmonth(ds[String]('a)), + sparkFunctions.dayofmonth + ) } check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply)))) @@ -2309,9 +2500,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = { + def prop( + data: List[X1[String]] + )(implicit + E: Encoder[Option[Int]] + ): Prop = { val ds = TypedDataset.create(data) - dateTimeStringProp(ds)(dayofyear(ds[String]('a)), sparkFunctions.dayofyear) + dateTimeStringProp(ds)( + dayofyear(ds[String]('a)), + sparkFunctions.dayofyear + ) } check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply)))) @@ -2322,7 +2520,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = { + def prop( + data: List[X1[String]] + )(implicit + E: Encoder[Option[Int]] + ): Prop = { val ds = TypedDataset.create(data) dateTimeStringProp(ds)(hour(ds[String]('a)), sparkFunctions.hour) } @@ -2335,7 +2537,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = { + def prop( + data: List[X1[String]] + )(implicit + E: Encoder[Option[Int]] + ): Prop = { val ds = TypedDataset.create(data) dateTimeStringProp(ds)(minute(ds[String]('a)), sparkFunctions.minute) } @@ -2348,7 +2554,11 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = { + def prop( + data: List[X1[String]] + )(implicit + E: Encoder[Option[Int]] + ): Prop = { val ds = TypedDataset.create(data) 
dateTimeStringProp(ds)(second(ds[String]('a)), sparkFunctions.second) } @@ -2361,9 +2571,16 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ - def prop(data: List[X1[String]])(implicit E: Encoder[Option[Int]]): Prop = { + def prop( + data: List[X1[String]] + )(implicit + E: Encoder[Option[Int]] + ): Prop = { val ds = TypedDataset.create(data) - dateTimeStringProp(ds)(weekofyear(ds[String]('a)), sparkFunctions.weekofyear) + dateTimeStringProp(ds)( + weekofyear(ds[String]('a)), + sparkFunctions.weekofyear + ) } check(forAll(dateTimeStringGen)(data => prop(data.map(X1.apply)))) diff --git a/dataset/src/test/scala/frameless/functions/UdfTests.scala b/dataset/src/test/scala/frameless/functions/UdfTests.scala index 10e65180f..ed1039640 100644 --- a/dataset/src/test/scala/frameless/functions/UdfTests.scala +++ b/dataset/src/test/scala/frameless/functions/UdfTests.scala @@ -7,14 +7,22 @@ import org.scalacheck.Prop._ class UdfTests extends TypedDatasetSuite { test("one argument udf") { - def prop[A: TypedEncoder, B: TypedEncoder](data: Vector[X1[A]], f1: A => B): Prop = { + def prop[A: TypedEncoder, B: TypedEncoder]( + data: Vector[X1[A]], + f1: A => B + ): Prop = { val dataset: TypedDataset[X1[A]] = TypedDataset.create(data) val u1 = udf[X1[A], A, B](f1) val u2 = dataset.makeUDF(f1) val A = dataset.col[A]('a) // filter forces whole codegen - val codegen = dataset.deserialized.filter((_:X1[A]) => true).select(u1(A)).collect().run().toVector + val codegen = dataset.deserialized + .filter((_: X1[A]) => true) + .select(u1(A)) + .collect() + .run() + .toVector // otherwise it uses local relation val local = dataset.select(u2(A)).collect().run().toVector @@ -35,15 +43,22 @@ class UdfTests extends TypedDatasetSuite { check(forAll(prop[Option[Vector[String]], Option[Vector[String]]] _)) - def prop2[A: TypedEncoder, B: TypedEncoder](f: A => B)(a: A): Prop = prop(Vector(X1(a)), f) + def prop2[A: TypedEncoder, B: TypedEncoder](f: A => B)(a: A): Prop = + prop(Vector(X1(a)), f) - check(forAll(prop2[Int, Option[Int]](x => if (x % 2 == 0) Some(x) else None) _)) + check( + forAll(prop2[Int, Option[Int]](x => if (x % 2 == 0) Some(x) else None) _) + ) check(forAll(prop2[Option[Int], Int](x => x getOrElse 0) _)) } test("multiple one argument udf") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder] - (data: Vector[X3[A, B, C]], f1: A => A, f2: B => B, f3: C => C): Prop = { + def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]( + data: Vector[X3[A, B, C]], + f1: A => A, + f2: B => B, + f3: C => C + ): Prop = { val dataset = TypedDataset.create(data) val u11 = udf[X3[A, B, C], A, A](f1) val u21 = udf[X3[A, B, C], B, B](f2) @@ -55,8 +70,10 @@ class UdfTests extends TypedDatasetSuite { val B = dataset.col[B]('b) val C = dataset.col[C]('c) - val dataset21 = dataset.select(u11(A), u21(B), u31(C)).collect().run().toVector - val dataset22 = dataset.select(u12(A), u22(B), u32(C)).collect().run().toVector + val dataset21 = + dataset.select(u11(A), u21(B), u31(C)).collect().run().toVector + val dataset22 = + dataset.select(u12(A), u22(B), u32(C)).collect().run().toVector val d = data.map(x => (f1(x.a), f2(x.b), f3(x.c))) (dataset21 ?= d) && (dataset22 ?= d) @@ -69,8 +86,10 @@ class UdfTests extends TypedDatasetSuite { } test("two argument udf") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder] - (data: Vector[X3[A, B, C]], f1: (A, B) => C): Prop = { + def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]( + data: Vector[X3[A, B, 
C]], + f1: (A, B) => C + ): Prop = { val dataset = TypedDataset.create(data) val u1 = udf[X3[A, B, C], A, B, C](f1) val u2 = dataset.makeUDF(f1) @@ -89,8 +108,11 @@ class UdfTests extends TypedDatasetSuite { } test("multiple two argument udf") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder] - (data: Vector[X3[A, B, C]], f1: (A, B) => C, f2: (B, C) => A): Prop = { + def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]( + data: Vector[X3[A, B, C]], + f1: (A, B) => C, + f2: (B, C) => A + ): Prop = { val dataset = TypedDataset.create(data) val u11 = udf[X3[A, B, C], A, B, C](f1) val u12 = dataset.makeUDF(f1) @@ -101,8 +123,10 @@ class UdfTests extends TypedDatasetSuite { val B = dataset.col[B]('b) val C = dataset.col[C]('c) - val dataset21 = dataset.select(u11(A, B), u21(B, C)).collect().run().toVector - val dataset22 = dataset.select(u12(A, B), u22(B, C)).collect().run().toVector + val dataset21 = + dataset.select(u11(A, B), u21(B, C)).collect().run().toVector + val dataset22 = + dataset.select(u12(A, B), u22(B, C)).collect().run().toVector val d = data.map(x => (f1(x.a, x.b), f2(x.b, x.c))) (dataset21 ?= d) && (dataset22 ?= d) @@ -113,8 +137,10 @@ class UdfTests extends TypedDatasetSuite { } test("three argument udf") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder] - (data: Vector[X3[A, B, C]], f: (A, B, C) => C): Prop = { + def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]( + data: Vector[X3[A, B, C]], + f: (A, B, C) => C + ): Prop = { val dataset = TypedDataset.create(data) val u1 = udf[X3[A, B, C], A, B, C, C](f) val u2 = dataset.makeUDF(f) @@ -135,8 +161,14 @@ class UdfTests extends TypedDatasetSuite { } test("four argument udf") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder, D: TypedEncoder] - (data: Vector[X4[A, B, C, D]], f: (A, B, C, D) => C): Prop = { + def prop[ + A: TypedEncoder, + B: TypedEncoder, + C: TypedEncoder, + D: TypedEncoder + ](data: Vector[X4[A, B, C, D]], + f: (A, B, C, D) => C + ): Prop = { val dataset = TypedDataset.create(data) val u1 = udf[X4[A, B, C, D], A, B, C, D, C](f) val u2 = dataset.makeUDF(f) @@ -161,8 +193,15 @@ class UdfTests extends TypedDatasetSuite { } test("five argument udf") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder, D: TypedEncoder, E: TypedEncoder] - (data: Vector[X5[A, B, C, D, E]], f: (A, B, C, D, E) => C): Prop = { + def prop[ + A: TypedEncoder, + B: TypedEncoder, + C: TypedEncoder, + D: TypedEncoder, + E: TypedEncoder + ](data: Vector[X5[A, B, C, D, E]], + f: (A, B, C, D, E) => C + ): Prop = { val dataset = TypedDataset.create(data) val u1 = udf[X5[A, B, C, D, E], A, B, C, D, E, C](f) val u2 = dataset.makeUDF(f) diff --git a/dataset/src/test/scala/frameless/functions/UnaryFunctionsTest.scala b/dataset/src/test/scala/frameless/functions/UnaryFunctionsTest.scala index 009179be6..23335f519 100644 --- a/dataset/src/test/scala/frameless/functions/UnaryFunctionsTest.scala +++ b/dataset/src/test/scala/frameless/functions/UnaryFunctionsTest.scala @@ -10,7 +10,12 @@ import scala.reflect.ClassTag class UnaryFunctionsTest extends TypedDatasetSuite { test("size tests") { - def prop[F[X] <: Traversable[X] : CatalystSizableCollection, A](xs: List[X1[F[A]]])(implicit arb: Arbitrary[F[A]], enc: TypedEncoder[F[A]]): Prop = { + def prop[F[X] <: Traversable[X]: CatalystSizableCollection, A]( + xs: List[X1[F[A]]] + )(implicit + arb: Arbitrary[F[A]], + enc: TypedEncoder[F[A]] + ): Prop = { val tds = TypedDataset.create(xs) val framelessResults = 
tds.select(size(tds('a))).collect().run().toVector @@ -43,7 +48,12 @@ class UnaryFunctionsTest extends TypedDatasetSuite { } test("size on Map") { - def prop[A](xs: List[X1[Map[A, A]]])(implicit arb: Arbitrary[Map[A, A]], enc: TypedEncoder[Map[A, A]]): Prop = { + def prop[A]( + xs: List[X1[Map[A, A]]] + )(implicit + arb: Arbitrary[Map[A, A]], + enc: TypedEncoder[Map[A, A]] + ): Prop = { val tds = TypedDataset.create(xs) val framelessResults = tds.select(size(tds('a))).collect().run().toVector @@ -58,10 +68,15 @@ class UnaryFunctionsTest extends TypedDatasetSuite { } test("sort in ascending order") { - def prop[F[X] <: SeqLike[X, F[X]] : CatalystSortableCollection, A: Ordering](xs: List[X1[F[A]]])(implicit enc: TypedEncoder[F[A]]): Prop = { + def prop[F[X] <: SeqLike[X, F[X]]: CatalystSortableCollection, A: Ordering]( + xs: List[X1[F[A]]] + )(implicit + enc: TypedEncoder[F[A]] + ): Prop = { val tds = TypedDataset.create(xs) - val framelessResults = tds.select(sortAscending(tds('a))).collect().run().toVector + val framelessResults = + tds.select(sortAscending(tds('a))).collect().run().toVector val scalaResults = xs.map(x => x.a.sorted).toVector framelessResults ?= scalaResults @@ -78,10 +93,15 @@ class UnaryFunctionsTest extends TypedDatasetSuite { } test("sort in descending order") { - def prop[F[X] <: SeqLike[X, F[X]] : CatalystSortableCollection, A: Ordering](xs: List[X1[F[A]]])(implicit enc: TypedEncoder[F[A]]): Prop = { + def prop[F[X] <: SeqLike[X, F[X]]: CatalystSortableCollection, A: Ordering]( + xs: List[X1[F[A]]] + )(implicit + enc: TypedEncoder[F[A]] + ): Prop = { val tds = TypedDataset.create(xs) - val framelessResults = tds.select(sortDescending(tds('a))).collect().run().toVector + val framelessResults = + tds.select(sortDescending(tds('a))).collect().run().toVector val scalaResults = xs.map(x => x.a.sorted.reverse).toVector framelessResults ?= scalaResults @@ -98,18 +118,19 @@ class UnaryFunctionsTest extends TypedDatasetSuite { } test("sort on array test: ascending order") { - def prop[A: TypedEncoder : Ordering : ClassTag](xs: List[X1[Array[A]]]): Prop = { + def prop[A: TypedEncoder: Ordering: ClassTag]( + xs: List[X1[Array[A]]] + ): Prop = { val tds = TypedDataset.create(xs) - val framelessResults = tds.select(sortAscending(tds('a))).collect().run().toVector + val framelessResults = + tds.select(sortAscending(tds('a))).collect().run().toVector val scalaResults = xs.map(x => x.a.sorted).toVector Prop { - framelessResults - .zip(scalaResults) - .forall { - case (a, b) => a sameElements b - } + framelessResults.zip(scalaResults).forall { + case (a, b) => a sameElements b + } } } @@ -119,18 +140,19 @@ class UnaryFunctionsTest extends TypedDatasetSuite { } test("sort on array test: descending order") { - def prop[A: TypedEncoder : Ordering : ClassTag](xs: List[X1[Array[A]]]): Prop = { + def prop[A: TypedEncoder: Ordering: ClassTag]( + xs: List[X1[Array[A]]] + ): Prop = { val tds = TypedDataset.create(xs) - val framelessResults = tds.select(sortDescending(tds('a))).collect().run().toVector + val framelessResults = + tds.select(sortDescending(tds('a))).collect().run().toVector val scalaResults = xs.map(x => x.a.sorted.reverse).toVector Prop { - framelessResults - .zip(scalaResults) - .forall { - case (a, b) => a sameElements b - } + framelessResults.zip(scalaResults).forall { + case (a, b) => a sameElements b + } } } diff --git a/dataset/src/test/scala/frameless/ops/ColumnTypesTest.scala b/dataset/src/test/scala/frameless/ops/ColumnTypesTest.scala index 303eb2cbd..df094c7a2 100644 
--- a/dataset/src/test/scala/frameless/ops/ColumnTypesTest.scala +++ b/dataset/src/test/scala/frameless/ops/ColumnTypesTest.scala @@ -8,20 +8,30 @@ import shapeless.:: class ColumnTypesTest extends TypedDatasetSuite { test("test summoning") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder, D: TypedEncoder](data: Vector[X4[A, B, C, D]]): Prop = { + def prop[ + A: TypedEncoder, + B: TypedEncoder, + C: TypedEncoder, + D: TypedEncoder + ](data: Vector[X4[A, B, C, D]] + ): Prop = { val d: TypedDataset[X4[A, B, C, D]] = TypedDataset.create(data) val hlist = d('a) :: d('b) :: d('c) :: d('d) :: HNil - type TC[N] = TypedColumn[X4[A,B,C,D], N] + type TC[N] = TypedColumn[X4[A, B, C, D], N] type IN = TC[A] :: TC[B] :: TC[C] :: TC[D] :: HNil type OUT = A :: B :: C :: D :: HNil - implicitly[ColumnTypes.Aux[X4[A,B,C,D], IN, OUT]] + implicitly[ColumnTypes.Aux[X4[A, B, C, D], IN, OUT]] Prop.passed // successful compilation implies test correctness } check(forAll(prop[Int, String, X1[String], Boolean] _)) - check(forAll(prop[Vector[Int], Vector[Vector[String]], X1[String], Option[String]] _)) + check( + forAll( + prop[Vector[Int], Vector[Vector[String]], X1[String], Option[String]] _ + ) + ) } } diff --git a/dataset/src/test/scala/frameless/ops/CubeTests.scala b/dataset/src/test/scala/frameless/ops/CubeTests.scala index 7a06822b9..ae4c72d69 100644 --- a/dataset/src/test/scala/frameless/ops/CubeTests.scala +++ b/dataset/src/test/scala/frameless/ops/CubeTests.scala @@ -8,14 +8,23 @@ import org.scalacheck.Prop._ class CubeTests extends TypedDatasetSuite { test("cube('a).agg(count())") { - def prop[A: TypedEncoder : Ordering, Out: TypedEncoder : Numeric] - (data: List[X1[A]])(implicit summable: CatalystSummable[A, Out]): Prop = { + def prop[A: TypedEncoder: Ordering, Out: TypedEncoder: Numeric]( + data: List[X1[A]] + )(implicit + summable: CatalystSummable[A, Out] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) - val received = dataset.cube(A).agg(count()).collect().run().toVector.sortBy(_._2) - val expected = dataset.dataset.cube("a").count().collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1))).sortBy(_._2) + val received = + dataset.cube(A).agg(count()).collect().run().toVector.sortBy(_._2) + val expected = dataset.dataset + .cube("a") + .count() + .collect() + .toVector + .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1))) + .sortBy(_._2) received ?= expected } @@ -24,15 +33,29 @@ class CubeTests extends TypedDatasetSuite { } test("cube('a, 'b).agg(count())") { - def prop[A: TypedEncoder : Ordering, B: TypedEncoder, Out: TypedEncoder : Numeric] - (data: List[X2[A, B]])(implicit summable: CatalystSummable[B, Out]): Prop = { + def prop[ + A: TypedEncoder: Ordering, + B: TypedEncoder, + Out: TypedEncoder: Numeric + ](data: List[X2[A, B]] + )(implicit + summable: CatalystSummable[B, Out] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) - val received = dataset.cube(A, B).agg(count()).collect().run().toVector.sortBy(_._3) - val expected = dataset.dataset.cube("a", "b").count().collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[Long](2))).sortBy(_._3) + val received = + dataset.cube(A, B).agg(count()).collect().run().toVector.sortBy(_._3) + val expected = dataset.dataset + .cube("a", "b") + .count() + .collect() + .toVector + .map(row => + (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[Long](2)) + ) + .sortBy(_._3) 
received ?= expected } @@ -41,15 +64,27 @@ class CubeTests extends TypedDatasetSuite { } test("cube('a).agg(sum('b)") { - def prop[A: TypedEncoder : Ordering, B: TypedEncoder, Out: TypedEncoder : Numeric] - (data: List[X2[A, B]])(implicit summable: CatalystSummable[B, Out]): Prop = { + def prop[ + A: TypedEncoder: Ordering, + B: TypedEncoder, + Out: TypedEncoder: Numeric + ](data: List[X2[A, B]] + )(implicit + summable: CatalystSummable[B, Out] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) - val received = dataset.cube(A).agg(sum(B)).collect().run().toVector.sortBy(_._2) - val expected = dataset.dataset.cube("a").sum("b").collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[Out](1))).sortBy(_._2) + val received = + dataset.cube(A).agg(sum(B)).collect().run().toVector.sortBy(_._2) + val expected = dataset.dataset + .cube("a") + .sum("b") + .collect() + .toVector + .map(row => (Option(row.getAs[A](0)), row.getAs[Out](1))) + .sortBy(_._2) received ?= expected } @@ -58,15 +93,22 @@ class CubeTests extends TypedDatasetSuite { } test("cube('a).mapGroups('a, sum('b))") { - def prop[A: TypedEncoder : Ordering, B: TypedEncoder : Numeric] - (data: List[X2[A, B]]): Prop = { + def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Numeric]( + data: List[X2[A, B]] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) - val received = dataset.cube(A) - .deserialized.mapGroups { case (a, xs) => (a, xs.map(_.b).sum) } - .collect().run().toVector.sortBy(_._1) - val expected = data.groupBy(_.a).mapValues(_.map(_.b).sum).toVector.sortBy(_._1) + val received = dataset + .cube(A) + .deserialized + .mapGroups { case (a, xs) => (a, xs.map(_.b).sum) } + .collect() + .run() + .toVector + .sortBy(_._1) + val expected = + data.groupBy(_.a).mapValues(_.map(_.b).sum).toVector.sortBy(_._1) received ?= expected } @@ -76,16 +118,16 @@ class CubeTests extends TypedDatasetSuite { test("cube('a).agg(sum('b), sum('c)) to cube('a).agg(sum('a), sum('b), sum('a), sum('b), sum('a))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder, - C: TypedEncoder, - OutB: TypedEncoder : Numeric, - OutC: TypedEncoder : Numeric - ](data: List[X3[A, B, C]])( - implicit - summableB: CatalystSummable[B, OutB], - summableC: CatalystSummable[C, OutC] - ): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder, + C: TypedEncoder, + OutB: TypedEncoder: Numeric, + OutC: TypedEncoder: Numeric + ](data: List[X3[A, B, C]] + )(implicit + summableB: CatalystSummable[B, OutB], + summableC: CatalystSummable[C, OutC] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) @@ -94,37 +136,91 @@ class CubeTests extends TypedDatasetSuite { val framelessSumBC = dataset .cube(A) .agg(sum(B), sum(C)) - .collect().run().toVector.sortBy(_._1) + .collect() + .run() + .toVector + .sortBy(_._1) - val sparkSumBC = dataset.dataset.cube("a").sum("b", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2))) + val sparkSumBC = dataset.dataset + .cube("a") + .sum("b", "c") + .collect() + .toVector + .map(row => + (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2)) + ) .sortBy(_._1) val framelessSumBCB = dataset .cube(A) .agg(sum(B), sum(C), sum(B)) - .collect().run().toVector.sortBy(_._1) + .collect() + .run() + .toVector + .sortBy(_._1) - val sparkSumBCB = dataset.dataset.cube("a").sum("b", "c", "b").collect().toVector - .map(row => 
(Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2), row.getAs[OutB](3))) + val sparkSumBCB = dataset.dataset + .cube("a") + .sum("b", "c", "b") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + row.getAs[OutB](1), + row.getAs[OutC](2), + row.getAs[OutB](3) + ) + ) .sortBy(_._1) val framelessSumBCBC = dataset .cube(A) .agg(sum(B), sum(C), sum(B), sum(C)) - .collect().run().toVector.sortBy(_._1) + .collect() + .run() + .toVector + .sortBy(_._1) - val sparkSumBCBC = dataset.dataset.cube("a").sum("b", "c", "b", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2), row.getAs[OutB](3), row.getAs[OutC](4))) + val sparkSumBCBC = dataset.dataset + .cube("a") + .sum("b", "c", "b", "c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + row.getAs[OutB](1), + row.getAs[OutC](2), + row.getAs[OutB](3), + row.getAs[OutC](4) + ) + ) .sortBy(_._1) val framelessSumBCBCB = dataset .cube(A) .agg(sum(B), sum(C), sum(B), sum(C), sum(B)) - .collect().run().toVector.sortBy(_._1) + .collect() + .run() + .toVector + .sortBy(_._1) - val sparkSumBCBCB = dataset.dataset.cube("a").sum("b", "c", "b", "c", "b").collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2), row.getAs[OutB](3), row.getAs[OutC](4), row.getAs[OutB](5))) + val sparkSumBCBCB = dataset.dataset + .cube("a") + .sum("b", "c", "b", "c", "b") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + row.getAs[OutB](1), + row.getAs[OutC](2), + row.getAs[OutB](3), + row.getAs[OutC](4), + row.getAs[OutB](5) + ) + ) .sortBy(_._1) (framelessSumBC ?= sparkSumBC) @@ -138,17 +234,17 @@ class CubeTests extends TypedDatasetSuite { test("cube('a, 'b).agg(sum('c), sum('d))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering, - C: TypedEncoder, - D: TypedEncoder, - OutC: TypedEncoder : Numeric, - OutD: TypedEncoder : Numeric - ](data: List[X4[A, B, C, D]])( - implicit - summableC: CatalystSummable[C, OutC], - summableD: CatalystSummable[D, OutD] - ): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder, + D: TypedEncoder, + OutC: TypedEncoder: Numeric, + OutD: TypedEncoder: Numeric + ](data: List[X4[A, B, C, D]] + )(implicit + summableC: CatalystSummable[C, OutC], + summableD: CatalystSummable[D, OutD] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) @@ -158,11 +254,24 @@ class CubeTests extends TypedDatasetSuite { val framelessSumByAB = dataset .cube(A, B) .agg(sum(C), sum(D)) - .collect().run().toVector.sortBy(x => (x._1, x._2)) + .collect() + .run() + .toVector + .sortBy(x => (x._1, x._2)) val sparkSumByAB = dataset.dataset - .cube("a", "b").sum("c", "d").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutD](3))) + .cube("a", "b") + .sum("c", "d") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + Option(row.getAs[B](1)), + row.getAs[OutC](2), + row.getAs[OutD](3) + ) + ) .sortBy(x => (x._1, x._2)) framelessSumByAB ?= sparkSumByAB @@ -173,76 +282,135 @@ class CubeTests extends TypedDatasetSuite { test("cube('a, 'b).agg(sum('c)) to cube('a, 'b).agg(sum('c),sum('c),sum('c),sum('c),sum('c))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering, - C: TypedEncoder, - OutC: TypedEncoder: Numeric - ](data: List[X3[A, B, C]])(implicit summableC: CatalystSummable[C, OutC]): Prop = { + A: 
TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder, + OutC: TypedEncoder: Numeric + ](data: List[X3[A, B, C]] + )(implicit + summableC: CatalystSummable[C, OutC] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) val C = dataset.col[C]('c) - val framelessSumC = dataset - .cube(A, B) - .agg(sum(C)) - .collect().run().toVector - .sortBy(_._2) + val framelessSumC = + dataset.cube(A, B).agg(sum(C)).collect().run().toVector.sortBy(_._2) val sparkSumC = dataset.dataset - .cube("a", "b").sum("c").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2))) + .cube("a", "b") + .sum("c") + .collect() + .toVector + .map(row => + (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2)) + ) .sortBy(_._2) val framelessSumCC = dataset .cube(A, B) .agg(sum(C), sum(C)) - .collect().run().toVector + .collect() + .run() + .toVector .sortBy(_._2) val sparkSumCC = dataset.dataset - .cube("a", "b").sum("c", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3))) + .cube("a", "b") + .sum("c", "c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + Option(row.getAs[B](1)), + row.getAs[OutC](2), + row.getAs[OutC](3) + ) + ) .sortBy(_._2) val framelessSumCCC = dataset .cube(A, B) .agg(sum(C), sum(C), sum(C)) - .collect().run().toVector + .collect() + .run() + .toVector .sortBy(_._2) val sparkSumCCC = dataset.dataset - .cube("a", "b").sum("c", "c", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3), row.getAs[OutC](4))) + .cube("a", "b") + .sum("c", "c", "c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + Option(row.getAs[B](1)), + row.getAs[OutC](2), + row.getAs[OutC](3), + row.getAs[OutC](4) + ) + ) .sortBy(_._2) val framelessSumCCCC = dataset .cube(A, B) .agg(sum(C), sum(C), sum(C), sum(C)) - .collect().run().toVector + .collect() + .run() + .toVector .sortBy(_._2) val sparkSumCCCC = dataset.dataset - .cube("a", "b").sum("c", "c", "c", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3), row.getAs[OutC](4), row.getAs[OutC](5))) + .cube("a", "b") + .sum("c", "c", "c", "c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + Option(row.getAs[B](1)), + row.getAs[OutC](2), + row.getAs[OutC](3), + row.getAs[OutC](4), + row.getAs[OutC](5) + ) + ) .sortBy(_._2) val framelessSumCCCCC = dataset .cube(A, B) .agg(sum(C), sum(C), sum(C), sum(C), sum(C)) - .collect().run().toVector + .collect() + .run() + .toVector .sortBy(_._2) val sparkSumCCCCC = dataset.dataset - .cube("a", "b").sum("c", "c", "c", "c", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3), row.getAs[OutC](4), row.getAs[OutC](5), row.getAs[OutC](6))) + .cube("a", "b") + .sum("c", "c", "c", "c", "c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + Option(row.getAs[B](1)), + row.getAs[OutC](2), + row.getAs[OutC](3), + row.getAs[OutC](4), + row.getAs[OutC](5), + row.getAs[OutC](6) + ) + ) .sortBy(_._2) (framelessSumC ?= sparkSumC) && - (framelessSumCC ?= sparkSumCC) && - (framelessSumCCC ?= sparkSumCCC) && - (framelessSumCCCC ?= sparkSumCCCC) && - (framelessSumCCCCC ?= sparkSumCCCCC) + (framelessSumCC ?= sparkSumCC) && + (framelessSumCCC ?= 
sparkSumCCC) && + (framelessSumCCCC ?= sparkSumCCCC) && + (framelessSumCCCCC ?= sparkSumCCCCC) } check(forAll(prop[String, Long, Double, Double] _)) @@ -250,22 +418,30 @@ class CubeTests extends TypedDatasetSuite { test("cube('a, 'b).mapGroups('a, 'b, sum('c))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering, - C: TypedEncoder : Numeric - ](data: List[X3[A, B, C]]): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder: Numeric + ](data: List[X3[A, B, C]] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) val framelessSumByAB = dataset .cube(A, B) - .deserialized.mapGroups { case ((a, b), xs) => (a, b, xs.map(_.c).sum) } - .collect().run().toVector.sortBy(x => (x._1, x._2)) + .deserialized + .mapGroups { case ((a, b), xs) => (a, b, xs.map(_.c).sum) } + .collect() + .run() + .toVector + .sortBy(x => (x._1, x._2)) - val sumByAB = data.groupBy(x => (x.a, x.b)) + val sumByAB = data + .groupBy(x => (x.a, x.b)) .mapValues { xs => xs.map(_.c).sum } - .toVector.map { case ((a, b), c) => (a, b, c) }.sortBy(x => (x._1, x._2)) + .toVector + .map { case ((a, b), c) => (a, b, c) } + .sortBy(x => (x._1, x._2)) framelessSumByAB ?= sumByAB } @@ -274,17 +450,19 @@ class CubeTests extends TypedDatasetSuite { } test("cube('a).mapGroups(('a, toVector(('a, 'b))") { - def prop[ - A: TypedEncoder: Ordering, - B: TypedEncoder: Ordering, - ](data: Vector[X2[A, B]]): Prop = { + def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Ordering]( + data: Vector[X2[A, B]] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val datasetGrouped = dataset .cube(A) - .deserialized.mapGroups((a, xs) => (a, xs.toVector.sorted)) - .collect().run().toMap + .deserialized + .mapGroups((a, xs) => (a, xs.toVector.sorted)) + .collect() + .run() + .toMap val dataGrouped = data.groupBy(_.a).map { case (k, v) => k -> v.sorted } @@ -297,21 +475,23 @@ class CubeTests extends TypedDatasetSuite { } test("cube('a).flatMapGroups(('a, toVector(('a, 'b))") { - def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering - ](data: Vector[X2[A, B]]): Prop = { + def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Ordering]( + data: Vector[X2[A, B]] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val datasetGrouped = dataset .cube(A) - .deserialized.flatMapGroups((a, xs) => xs.map(x => (a, x))) - .collect().run() + .deserialized + .flatMapGroups((a, xs) => xs.map(x => (a, x))) + .collect() + .run() .sorted val dataGrouped = data - .groupBy(_.a).toSeq + .groupBy(_.a) + .toSeq .flatMap { case (a, xs) => xs.map(x => (a, x)) } .sorted @@ -325,22 +505,26 @@ class CubeTests extends TypedDatasetSuite { test("cube('a, 'b).flatMapGroups((('a,'b) toVector((('a,'b), 'c))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering, - C: TypedEncoder : Ordering - ](data: Vector[X3[A, B, C]]): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder: Ordering + ](data: Vector[X3[A, B, C]] + ): Prop = { val dataset = TypedDataset.create(data) val cA = dataset.col[A]('a) val cB = dataset.col[B]('b) val datasetGrouped = dataset .cube(cA, cB) - .deserialized.flatMapGroups((a, xs) => xs.map(x => (a, x))) - .collect().run() + .deserialized + .flatMapGroups((a, xs) => xs.map(x => (a, x))) + .collect() + .run() .sorted val dataGrouped = data - .groupBy(t => (t.a, t.b)).toSeq + .groupBy(t => (t.a, t.b)) + .toSeq .flatMap { case (a, xs) => xs.map(x => 
(a, x)) } .sorted @@ -353,18 +537,32 @@ class CubeTests extends TypedDatasetSuite { } test("cubeMany('a).agg(sum('b))") { - def prop[A: TypedEncoder : Ordering, Out: TypedEncoder : Numeric] - (data: List[X1[A]])(implicit summable: CatalystSummable[A, Out]): Prop = { + def prop[A: TypedEncoder: Ordering, Out: TypedEncoder: Numeric]( + data: List[X1[A]] + )(implicit + summable: CatalystSummable[A, Out] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) - val received = dataset.cubeMany(A).agg(count[X1[A]]()).collect().run().toVector.sortBy(_._2) - val expected = dataset.dataset.cube("a").count().collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1))).sortBy(_._2) + val received = dataset + .cubeMany(A) + .agg(count[X1[A]]()) + .collect() + .run() + .toVector + .sortBy(_._2) + val expected = dataset.dataset + .cube("a") + .count() + .collect() + .toVector + .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1))) + .sortBy(_._2) received ?= expected } check(forAll(prop[Int, Long] _)) } -} \ No newline at end of file +} diff --git a/dataset/src/test/scala/frameless/ops/PivotTest.scala b/dataset/src/test/scala/frameless/ops/PivotTest.scala index dd9bf5e61..f3f2bbcf9 100644 --- a/dataset/src/test/scala/frameless/ops/PivotTest.scala +++ b/dataset/src/test/scala/frameless/ops/PivotTest.scala @@ -2,12 +2,13 @@ package frameless package ops import frameless.functions.aggregate._ -import org.apache.spark.sql.{functions => sparkFunctions} +import org.apache.spark.sql.{ functions => sparkFunctions } import org.scalacheck.Arbitrary.arbitrary import org.scalacheck.Prop._ -import org.scalacheck.{Gen, Prop} +import org.scalacheck.{ Gen, Prop } class PivotTest extends TypedDatasetSuite { + def withCustomGenX4: Gen[Vector[X4[String, String, Int, Boolean]]] = { val kvPairGen: Gen[X4[String, String, Int, Boolean]] = for { a <- Gen.oneOf(Seq("1", "2", "3", "4")) @@ -22,77 +23,109 @@ class PivotTest extends TypedDatasetSuite { test("X4[Boolean, String, Int, Boolean] pivot on String") { def prop(data: Vector[X4[String, String, Int, Boolean]]): Prop = { val d = TypedDataset.create(data) - val frameless = d.groupBy(d('a)). - pivot(d('b)).on("a", "b", "c"). 
- agg(sum(d('c)), first(d('d))).collect().run().toVector + val frameless = d + .groupBy(d('a)) + .pivot(d('b)) + .on("a", "b", "c") + .agg(sum(d('c)), first(d('d))) + .collect() + .run() + .toVector - val spark = d.dataset.groupBy("a") + val spark = d.dataset + .groupBy("a") .pivot("b", Seq("a", "b", "c")) - .agg(sparkFunctions.sum("c"), sparkFunctions.first("d")).collect().toVector + .agg(sparkFunctions.sum("c"), sparkFunctions.first("d")) + .collect() + .toVector - (frameless.map(_._1) ?= spark.map(x => x.getAs[String](0))).&&( - frameless.map(_._2) ?= spark.map(x => Option(x.getAs[Long](1)))).&&( - frameless.map(_._3) ?= spark.map(x => Option(x.getAs[Boolean](2)))).&&( - frameless.map(_._4) ?= spark.map(x => Option(x.getAs[Long](3)))).&&( - frameless.map(_._5) ?= spark.map(x => Option(x.getAs[Boolean](4)))).&&( - frameless.map(_._6) ?= spark.map(x => Option(x.getAs[Long](5)))).&&( - frameless.map(_._7) ?= spark.map(x => Option(x.getAs[Boolean](6)))) + (frameless.map(_._1) ?= spark.map(x => x.getAs[String](0))) + .&&(frameless.map(_._2) ?= spark.map(x => Option(x.getAs[Long](1)))) + .&&(frameless.map(_._3) ?= spark.map(x => Option(x.getAs[Boolean](2)))) + .&&(frameless.map(_._4) ?= spark.map(x => Option(x.getAs[Long](3)))) + .&&(frameless.map(_._5) ?= spark.map(x => Option(x.getAs[Boolean](4)))) + .&&(frameless.map(_._6) ?= spark.map(x => Option(x.getAs[Long](5)))) + .&&(frameless.map(_._7) ?= spark.map(x => Option(x.getAs[Boolean](6)))) } check(forAll(withCustomGenX4)(prop)) } test("Pivot on Boolean") { - val x: Seq[X3[String, Boolean, Boolean]] = Seq(X3("a", true, true), X3("a", true, true), X3("a", true, false)) + val x: Seq[X3[String, Boolean, Boolean]] = + Seq(X3("a", true, true), X3("a", true, true), X3("a", true, false)) val d = TypedDataset.create(x) - d.groupByMany(d('a)). - pivot(d('c)).on(true, false). - agg(count[X3[String, Boolean, Boolean]]()). - collect().run().toVector ?= Vector(("a", Some(2L), Some(1L))) // two true one false + d.groupByMany(d('a)) + .pivot(d('c)) + .on(true, false) + .agg(count[X3[String, Boolean, Boolean]]()) + .collect() + .run() + .toVector ?= Vector(("a", Some(2L), Some(1L))) // two true one false } test("Pivot with groupBy on two columns, pivot on Long") { - val x: Seq[X3[String, String, Long]] = Seq(X3("a", "x", 1), X3("a", "x", 1), X3("a", "c", 20)) + val x: Seq[X3[String, String, Long]] = + Seq(X3("a", "x", 1), X3("a", "x", 1), X3("a", "c", 20)) val d = TypedDataset.create(x) - d.groupBy(d('a), d('b)). - pivot(d('c)).on(1L, 20L). - agg(count[X3[String, String, Long]]()). 
- collect().run().toSet ?= Set(("a", "x", Some(2L), None), ("a", "c", None, Some(1L))) + d.groupBy(d('a), d('b)) + .pivot(d('c)) + .on(1L, 20L) + .agg(count[X3[String, String, Long]]()) + .collect() + .run() + .toSet ?= Set(("a", "x", Some(2L), None), ("a", "c", None, Some(1L))) } test("Pivot with cube on two columns, pivot on Long") { - val x: Seq[X3[String, String, Long]] = Seq(X3("a", "x", 1), X3("a", "x", 1), X3("a", "c", 20)) + val x: Seq[X3[String, String, Long]] = + Seq(X3("a", "x", 1), X3("a", "x", 1), X3("a", "c", 20)) val d = TypedDataset.create(x) d.cube(d('a), d('b)) - .pivot(d('c)).on(1L, 20L) + .pivot(d('c)) + .on(1L, 20L) .agg(count[X3[String, String, Long]]()) - .collect().run().toSet ?= Set(("a", "x", Some(2L), None), ("a", "c", None, Some(1L))) + .collect() + .run() + .toSet ?= Set(("a", "x", Some(2L), None), ("a", "c", None, Some(1L))) } test("Pivot with cube on Boolean") { - val x: Seq[X3[String, Boolean, Boolean]] = Seq(X3("a", true, true), X3("a", true, true), X3("a", true, false)) + val x: Seq[X3[String, Boolean, Boolean]] = + Seq(X3("a", true, true), X3("a", true, true), X3("a", true, false)) val d = TypedDataset.create(x) - d.cube(d('a)). - pivot(d('c)).on(true, false). - agg(count[X3[String, Boolean, Boolean]]()). - collect().run().toVector ?= Vector(("a", Some(2L), Some(1L))) + d.cube(d('a)) + .pivot(d('c)) + .on(true, false) + .agg(count[X3[String, Boolean, Boolean]]()) + .collect() + .run() + .toVector ?= Vector(("a", Some(2L), Some(1L))) } test("Pivot with rollup on two columns, pivot on Long") { - val x: Seq[X3[String, String, Long]] = Seq(X3("a", "x", 1), X3("a", "x", 1), X3("a", "c", 20)) + val x: Seq[X3[String, String, Long]] = + Seq(X3("a", "x", 1), X3("a", "x", 1), X3("a", "c", 20)) val d = TypedDataset.create(x) d.rollup(d('a), d('b)) - .pivot(d('c)).on(1L, 20L) + .pivot(d('c)) + .on(1L, 20L) .agg(count[X3[String, String, Long]]()) - .collect().run().toSet ?= Set(("a", "x", Some(2L), None), ("a", "c", None, Some(1L))) + .collect() + .run() + .toSet ?= Set(("a", "x", Some(2L), None), ("a", "c", None, Some(1L))) } test("Pivot with rollup on Boolean") { - val x: Seq[X3[String, Boolean, Boolean]] = Seq(X3("a", true, true), X3("a", true, true), X3("a", true, false)) + val x: Seq[X3[String, Boolean, Boolean]] = + Seq(X3("a", true, true), X3("a", true, true), X3("a", true, false)) val d = TypedDataset.create(x) - d.rollupMany(d('a)). - pivot(d('c)).on(true, false). - agg(count[X3[String, Boolean, Boolean]]()). 
- collect().run().toVector ?= Vector(("a", Some(2L), Some(1L))) + d.rollupMany(d('a)) + .pivot(d('c)) + .on(true, false) + .agg(count[X3[String, Boolean, Boolean]]()) + .collect() + .run() + .toVector ?= Vector(("a", Some(2L), Some(1L))) } -} \ No newline at end of file +} diff --git a/dataset/src/test/scala/frameless/ops/RepeatTest.scala b/dataset/src/test/scala/frameless/ops/RepeatTest.scala index 78dfc6410..0dd6905fb 100644 --- a/dataset/src/test/scala/frameless/ops/RepeatTest.scala +++ b/dataset/src/test/scala/frameless/ops/RepeatTest.scala @@ -2,17 +2,31 @@ package frameless package ops import shapeless.test.illTyped -import shapeless.{::, HNil, Nat} +import shapeless.{ ::, HNil, Nat } class RepeatTest extends TypedDatasetSuite { test("summoning with implicitly") { - implicitly[Repeat.Aux[Int::Boolean::HNil, Nat._1, Int::Boolean::HNil]] - implicitly[Repeat.Aux[Int::Boolean::HNil, Nat._2, Int::Boolean::Int::Boolean::HNil]] - implicitly[Repeat.Aux[Int::Boolean::HNil, Nat._3, Int::Boolean::Int::Boolean::Int::Boolean::HNil]] - implicitly[Repeat.Aux[String::HNil, Nat._5, String::String::String::String::String::HNil]] + implicitly[ + Repeat.Aux[Int :: Boolean :: HNil, Nat._1, Int :: Boolean :: HNil] + ] + implicitly[Repeat.Aux[ + Int :: Boolean :: HNil, + Nat._2, + Int :: Boolean :: Int :: Boolean :: HNil + ]] + implicitly[Repeat.Aux[ + Int :: Boolean :: HNil, + Nat._3, + Int :: Boolean :: Int :: Boolean :: Int :: Boolean :: HNil + ]] + implicitly[Repeat.Aux[ + String :: HNil, + Nat._5, + String :: String :: String :: String :: String :: HNil + ]] } test("ill typed") { illTyped("""implicitly[Repeat.Aux[String::HNil, Nat._5, String::String::String::String::HNil]]""") } -} \ No newline at end of file +} diff --git a/dataset/src/test/scala/frameless/ops/RollupTests.scala b/dataset/src/test/scala/frameless/ops/RollupTests.scala index da73ef8d0..cdccba47d 100644 --- a/dataset/src/test/scala/frameless/ops/RollupTests.scala +++ b/dataset/src/test/scala/frameless/ops/RollupTests.scala @@ -8,14 +8,23 @@ import org.scalacheck.Prop._ class RollupTests extends TypedDatasetSuite { test("rollup('a).agg(count())") { - def prop[A: TypedEncoder : Ordering, Out: TypedEncoder : Numeric] - (data: List[X1[A]])(implicit summable: CatalystSummable[A, Out]): Prop = { + def prop[A: TypedEncoder: Ordering, Out: TypedEncoder: Numeric]( + data: List[X1[A]] + )(implicit + summable: CatalystSummable[A, Out] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) - val received = dataset.rollup(A).agg(count()).collect().run().toVector.sortBy(_._2) - val expected = dataset.dataset.rollup("a").count().collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1))).sortBy(_._2) + val received = + dataset.rollup(A).agg(count()).collect().run().toVector.sortBy(_._2) + val expected = dataset.dataset + .rollup("a") + .count() + .collect() + .toVector + .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1))) + .sortBy(_._2) received ?= expected } @@ -24,15 +33,29 @@ class RollupTests extends TypedDatasetSuite { } test("rollup('a, 'b).agg(count())") { - def prop[A: TypedEncoder : Ordering, B: TypedEncoder, Out: TypedEncoder : Numeric] - (data: List[X2[A, B]])(implicit summable: CatalystSummable[B, Out]): Prop = { + def prop[ + A: TypedEncoder: Ordering, + B: TypedEncoder, + Out: TypedEncoder: Numeric + ](data: List[X2[A, B]] + )(implicit + summable: CatalystSummable[B, Out] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) - val 
received = dataset.rollup(A, B).agg(count()).collect().run().toVector.sortBy(_._3) - val expected = dataset.dataset.rollup("a", "b").count().collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[Long](2))).sortBy(_._3) + val received = + dataset.rollup(A, B).agg(count()).collect().run().toVector.sortBy(_._3) + val expected = dataset.dataset + .rollup("a", "b") + .count() + .collect() + .toVector + .map(row => + (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[Long](2)) + ) + .sortBy(_._3) received ?= expected } @@ -41,15 +64,27 @@ class RollupTests extends TypedDatasetSuite { } test("rollup('a).agg(sum('b)") { - def prop[A: TypedEncoder : Ordering, B: TypedEncoder, Out: TypedEncoder : Numeric] - (data: List[X2[A, B]])(implicit summable: CatalystSummable[B, Out]): Prop = { + def prop[ + A: TypedEncoder: Ordering, + B: TypedEncoder, + Out: TypedEncoder: Numeric + ](data: List[X2[A, B]] + )(implicit + summable: CatalystSummable[B, Out] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) - val received = dataset.rollup(A).agg(sum(B)).collect().run().toVector.sortBy(_._2) - val expected = dataset.dataset.rollup("a").sum("b").collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[Out](1))).sortBy(_._2) + val received = + dataset.rollup(A).agg(sum(B)).collect().run().toVector.sortBy(_._2) + val expected = dataset.dataset + .rollup("a") + .sum("b") + .collect() + .toVector + .map(row => (Option(row.getAs[A](0)), row.getAs[Out](1))) + .sortBy(_._2) received ?= expected } @@ -58,15 +93,22 @@ class RollupTests extends TypedDatasetSuite { } test("rollup('a).mapGroups('a, sum('b))") { - def prop[A: TypedEncoder : Ordering, B: TypedEncoder : Numeric] - (data: List[X2[A, B]]): Prop = { + def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Numeric]( + data: List[X2[A, B]] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) - val received = dataset.rollup(A) - .deserialized.mapGroups { case (a, xs) => (a, xs.map(_.b).sum) } - .collect().run().toVector.sortBy(_._1) - val expected = data.groupBy(_.a).mapValues(_.map(_.b).sum).toVector.sortBy(_._1) + val received = dataset + .rollup(A) + .deserialized + .mapGroups { case (a, xs) => (a, xs.map(_.b).sum) } + .collect() + .run() + .toVector + .sortBy(_._1) + val expected = + data.groupBy(_.a).mapValues(_.map(_.b).sum).toVector.sortBy(_._1) received ?= expected } @@ -76,16 +118,16 @@ class RollupTests extends TypedDatasetSuite { test("rollup('a).agg(sum('b), sum('c)) to rollup('a).agg(sum('a), sum('b), sum('a), sum('b), sum('a))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder, - C: TypedEncoder, - OutB: TypedEncoder : Numeric, - OutC: TypedEncoder : Numeric - ](data: List[X3[A, B, C]])( - implicit - summableB: CatalystSummable[B, OutB], - summableC: CatalystSummable[C, OutC] - ): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder, + C: TypedEncoder, + OutB: TypedEncoder: Numeric, + OutC: TypedEncoder: Numeric + ](data: List[X3[A, B, C]] + )(implicit + summableB: CatalystSummable[B, OutB], + summableC: CatalystSummable[C, OutC] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) @@ -94,37 +136,91 @@ class RollupTests extends TypedDatasetSuite { val framelessSumBC = dataset .rollup(A) .agg(sum(B), sum(C)) - .collect().run().toVector.sortBy(_._1) + .collect() + .run() + .toVector + .sortBy(_._1) - val sparkSumBC = 
dataset.dataset.rollup("a").sum("b", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2))) + val sparkSumBC = dataset.dataset + .rollup("a") + .sum("b", "c") + .collect() + .toVector + .map(row => + (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2)) + ) .sortBy(_._1) val framelessSumBCB = dataset .rollup(A) .agg(sum(B), sum(C), sum(B)) - .collect().run().toVector.sortBy(_._1) + .collect() + .run() + .toVector + .sortBy(_._1) - val sparkSumBCB = dataset.dataset.rollup("a").sum("b", "c", "b").collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2), row.getAs[OutB](3))) + val sparkSumBCB = dataset.dataset + .rollup("a") + .sum("b", "c", "b") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + row.getAs[OutB](1), + row.getAs[OutC](2), + row.getAs[OutB](3) + ) + ) .sortBy(_._1) val framelessSumBCBC = dataset .rollup(A) .agg(sum(B), sum(C), sum(B), sum(C)) - .collect().run().toVector.sortBy(_._1) + .collect() + .run() + .toVector + .sortBy(_._1) - val sparkSumBCBC = dataset.dataset.rollup("a").sum("b", "c", "b", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2), row.getAs[OutB](3), row.getAs[OutC](4))) + val sparkSumBCBC = dataset.dataset + .rollup("a") + .sum("b", "c", "b", "c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + row.getAs[OutB](1), + row.getAs[OutC](2), + row.getAs[OutB](3), + row.getAs[OutC](4) + ) + ) .sortBy(_._1) val framelessSumBCBCB = dataset .rollup(A) .agg(sum(B), sum(C), sum(B), sum(C), sum(B)) - .collect().run().toVector.sortBy(_._1) + .collect() + .run() + .toVector + .sortBy(_._1) - val sparkSumBCBCB = dataset.dataset.rollup("a").sum("b", "c", "b", "c", "b").collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[OutB](1), row.getAs[OutC](2), row.getAs[OutB](3), row.getAs[OutC](4), row.getAs[OutB](5))) + val sparkSumBCBCB = dataset.dataset + .rollup("a") + .sum("b", "c", "b", "c", "b") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + row.getAs[OutB](1), + row.getAs[OutC](2), + row.getAs[OutB](3), + row.getAs[OutC](4), + row.getAs[OutB](5) + ) + ) .sortBy(_._1) (framelessSumBC ?= sparkSumBC) @@ -138,17 +234,17 @@ class RollupTests extends TypedDatasetSuite { test("rollup('a, 'b).agg(sum('c), sum('d))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering, - C: TypedEncoder, - D: TypedEncoder, - OutC: TypedEncoder : Numeric, - OutD: TypedEncoder : Numeric - ](data: List[X4[A, B, C, D]])( - implicit - summableC: CatalystSummable[C, OutC], - summableD: CatalystSummable[D, OutD] - ): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder, + D: TypedEncoder, + OutC: TypedEncoder: Numeric, + OutD: TypedEncoder: Numeric + ](data: List[X4[A, B, C, D]] + )(implicit + summableC: CatalystSummable[C, OutC], + summableD: CatalystSummable[D, OutD] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) @@ -158,11 +254,24 @@ class RollupTests extends TypedDatasetSuite { val framelessSumByAB = dataset .rollup(A, B) .agg(sum(C), sum(D)) - .collect().run().toVector.sortBy(_._2) + .collect() + .run() + .toVector + .sortBy(_._2) val sparkSumByAB = dataset.dataset - .rollup("a", "b").sum("c", "d").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutD](3))) + .rollup("a", "b") + .sum("c", "d") 
+ .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + Option(row.getAs[B](1)), + row.getAs[OutC](2), + row.getAs[OutD](3) + ) + ) .sortBy(_._2) framelessSumByAB ?= sparkSumByAB @@ -173,76 +282,135 @@ class RollupTests extends TypedDatasetSuite { test("rollup('a, 'b).agg(sum('c)) to rollup('a, 'b).agg(sum('c),sum('c),sum('c),sum('c),sum('c))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering, - C: TypedEncoder, - OutC: TypedEncoder: Numeric - ](data: List[X3[A, B, C]])(implicit summableC: CatalystSummable[C, OutC]): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder, + OutC: TypedEncoder: Numeric + ](data: List[X3[A, B, C]] + )(implicit + summableC: CatalystSummable[C, OutC] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) val C = dataset.col[C]('c) - val framelessSumC = dataset - .rollup(A, B) - .agg(sum(C)) - .collect().run().toVector - .sortBy(_._2) + val framelessSumC = + dataset.rollup(A, B).agg(sum(C)).collect().run().toVector.sortBy(_._2) val sparkSumC = dataset.dataset - .rollup("a", "b").sum("c").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2))) + .rollup("a", "b") + .sum("c") + .collect() + .toVector + .map(row => + (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2)) + ) .sortBy(_._2) val framelessSumCC = dataset .rollup(A, B) .agg(sum(C), sum(C)) - .collect().run().toVector + .collect() + .run() + .toVector .sortBy(_._2) val sparkSumCC = dataset.dataset - .rollup("a", "b").sum("c", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3))) + .rollup("a", "b") + .sum("c", "c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + Option(row.getAs[B](1)), + row.getAs[OutC](2), + row.getAs[OutC](3) + ) + ) .sortBy(_._2) val framelessSumCCC = dataset .rollup(A, B) .agg(sum(C), sum(C), sum(C)) - .collect().run().toVector + .collect() + .run() + .toVector .sortBy(_._2) val sparkSumCCC = dataset.dataset - .rollup("a", "b").sum("c", "c", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3), row.getAs[OutC](4))) + .rollup("a", "b") + .sum("c", "c", "c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + Option(row.getAs[B](1)), + row.getAs[OutC](2), + row.getAs[OutC](3), + row.getAs[OutC](4) + ) + ) .sortBy(_._2) val framelessSumCCCC = dataset .rollup(A, B) .agg(sum(C), sum(C), sum(C), sum(C)) - .collect().run().toVector + .collect() + .run() + .toVector .sortBy(_._2) val sparkSumCCCC = dataset.dataset - .rollup("a", "b").sum("c", "c", "c", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), row.getAs[OutC](3), row.getAs[OutC](4), row.getAs[OutC](5))) + .rollup("a", "b") + .sum("c", "c", "c", "c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + Option(row.getAs[B](1)), + row.getAs[OutC](2), + row.getAs[OutC](3), + row.getAs[OutC](4), + row.getAs[OutC](5) + ) + ) .sortBy(_._2) val framelessSumCCCCC = dataset .rollup(A, B) .agg(sum(C), sum(C), sum(C), sum(C), sum(C)) - .collect().run().toVector + .collect() + .run() + .toVector .sortBy(_._2) val sparkSumCCCCC = dataset.dataset - .rollup("a", "b").sum("c", "c", "c", "c", "c").collect().toVector - .map(row => (Option(row.getAs[A](0)), Option(row.getAs[B](1)), row.getAs[OutC](2), 
row.getAs[OutC](3), row.getAs[OutC](4), row.getAs[OutC](5), row.getAs[OutC](6))) + .rollup("a", "b") + .sum("c", "c", "c", "c", "c") + .collect() + .toVector + .map(row => + ( + Option(row.getAs[A](0)), + Option(row.getAs[B](1)), + row.getAs[OutC](2), + row.getAs[OutC](3), + row.getAs[OutC](4), + row.getAs[OutC](5), + row.getAs[OutC](6) + ) + ) .sortBy(_._2) (framelessSumC ?= sparkSumC) && - (framelessSumCC ?= sparkSumCC) && - (framelessSumCCC ?= sparkSumCCC) && - (framelessSumCCCC ?= sparkSumCCCC) && - (framelessSumCCCCC ?= sparkSumCCCCC) + (framelessSumCC ?= sparkSumCC) && + (framelessSumCCC ?= sparkSumCCC) && + (framelessSumCCCC ?= sparkSumCCCC) && + (framelessSumCCCCC ?= sparkSumCCCCC) } check(forAll(prop[String, Long, Double, Double] _)) @@ -250,22 +418,30 @@ class RollupTests extends TypedDatasetSuite { test("rollup('a, 'b).mapGroups('a, 'b, sum('c))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering, - C: TypedEncoder : Numeric - ](data: List[X3[A, B, C]]): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder: Numeric + ](data: List[X3[A, B, C]] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val B = dataset.col[B]('b) val framelessSumByAB = dataset .rollup(A, B) - .deserialized.mapGroups { case ((a, b), xs) => (a, b, xs.map(_.c).sum) } - .collect().run().toVector.sortBy(x => (x._1, x._2)) - - val sumByAB = data.groupBy(x => (x.a, x.b)) + .deserialized + .mapGroups { case ((a, b), xs) => (a, b, xs.map(_.c).sum) } + .collect() + .run() + .toVector + .sortBy(x => (x._1, x._2)) + + val sumByAB = data + .groupBy(x => (x.a, x.b)) .mapValues { xs => xs.map(_.c).sum } - .toVector.map { case ((a, b), c) => (a, b, c) }.sortBy(x => (x._1, x._2)) + .toVector + .map { case ((a, b), c) => (a, b, c) } + .sortBy(x => (x._1, x._2)) framelessSumByAB ?= sumByAB } @@ -274,17 +450,19 @@ class RollupTests extends TypedDatasetSuite { } test("rollup('a).mapGroups(('a, toVector(('a, 'b))") { - def prop[ - A: TypedEncoder: Ordering, - B: TypedEncoder: Ordering - ](data: Vector[X2[A, B]]): Prop = { + def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Ordering]( + data: Vector[X2[A, B]] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val datasetGrouped = dataset .rollup(A) - .deserialized.mapGroups((a, xs) => (a, xs.toVector.sorted)) - .collect().run().toMap + .deserialized + .mapGroups((a, xs) => (a, xs.toVector.sorted)) + .collect() + .run() + .toMap val dataGrouped = data.groupBy(_.a).map { case (k, v) => k -> v.sorted } @@ -297,21 +475,23 @@ class RollupTests extends TypedDatasetSuite { } test("rollup('a).flatMapGroups(('a, toVector(('a, 'b))") { - def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering - ](data: Vector[X2[A, B]]): Prop = { + def prop[A: TypedEncoder: Ordering, B: TypedEncoder: Ordering]( + data: Vector[X2[A, B]] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) val datasetGrouped = dataset .rollup(A) - .deserialized.flatMapGroups((a, xs) => xs.map(x => (a, x))) - .collect().run() + .deserialized + .flatMapGroups((a, xs) => xs.map(x => (a, x))) + .collect() + .run() .sorted val dataGrouped = data - .groupBy(_.a).toSeq + .groupBy(_.a) + .toSeq .flatMap { case (a, xs) => xs.map(x => (a, x)) } .sorted @@ -325,22 +505,26 @@ class RollupTests extends TypedDatasetSuite { test("rollup('a, 'b).flatMapGroups((('a,'b) toVector((('a,'b), 'c))") { def prop[ - A: TypedEncoder : Ordering, - B: TypedEncoder : Ordering, - C: TypedEncoder : 
Ordering - ](data: Vector[X3[A, B, C]]): Prop = { + A: TypedEncoder: Ordering, + B: TypedEncoder: Ordering, + C: TypedEncoder: Ordering + ](data: Vector[X3[A, B, C]] + ): Prop = { val dataset = TypedDataset.create(data) val cA = dataset.col[A]('a) val cB = dataset.col[B]('b) val datasetGrouped = dataset .rollup(cA, cB) - .deserialized.flatMapGroups((a, xs) => xs.map(x => (a, x))) - .collect().run() + .deserialized + .flatMapGroups((a, xs) => xs.map(x => (a, x))) + .collect() + .run() .sorted val dataGrouped = data - .groupBy(t => (t.a, t.b)).toSeq + .groupBy(t => (t.a, t.b)) + .toSeq .flatMap { case (a, xs) => xs.map(x => (a, x)) } .sorted @@ -353,18 +537,32 @@ class RollupTests extends TypedDatasetSuite { } test("rollupMany('a).agg(sum('b))") { - def prop[A: TypedEncoder : Ordering, Out: TypedEncoder : Numeric] - (data: List[X1[A]])(implicit summable: CatalystSummable[A, Out]): Prop = { + def prop[A: TypedEncoder: Ordering, Out: TypedEncoder: Numeric]( + data: List[X1[A]] + )(implicit + summable: CatalystSummable[A, Out] + ): Prop = { val dataset = TypedDataset.create(data) val A = dataset.col[A]('a) - val received = dataset.rollupMany(A).agg(count[X1[A]]()).collect().run().toVector.sortBy(_._2) - val expected = dataset.dataset.rollup("a").count().collect().toVector - .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1))).sortBy(_._2) + val received = dataset + .rollupMany(A) + .agg(count[X1[A]]()) + .collect() + .run() + .toVector + .sortBy(_._2) + val expected = dataset.dataset + .rollup("a") + .count() + .collect() + .toVector + .map(row => (Option(row.getAs[A](0)), row.getAs[Long](1))) + .sortBy(_._2) received ?= expected } check(forAll(prop[Int, Long] _)) } -} \ No newline at end of file +} diff --git a/dataset/src/test/scala/frameless/ops/SmartProjectTest.scala b/dataset/src/test/scala/frameless/ops/SmartProjectTest.scala index 233a42aec..8b507c1a5 100644 --- a/dataset/src/test/scala/frameless/ops/SmartProjectTest.scala +++ b/dataset/src/test/scala/frameless/ops/SmartProjectTest.scala @@ -5,15 +5,16 @@ import org.scalacheck.Prop import org.scalacheck.Prop._ import shapeless.test.illTyped - case class Foo(i: Int, j: Int, x: String) case class Bar(i: Int, x: String) case class InvalidFooProjectionType(i: Int, x: Boolean) case class InvalidFooProjectionName(i: Int, xerr: String) class SmartProjectTest extends TypedDatasetSuite { + // Lazy needed to prevent initialization anterior to the `beforeAll` hook - lazy val dataset = TypedDataset.create(Foo(1, 2, "hi") :: Foo(2, 3, "there") :: Nil) + lazy val dataset = + TypedDataset.create(Foo(1, 2, "hi") :: Foo(2, 3, "there") :: Nil) test("project Foo to Bar") { assert(dataset.project[Bar].count().run() === 2) @@ -25,28 +26,50 @@ class SmartProjectTest extends TypedDatasetSuite { } test("X4 to X1,X2,X3,X4 projections") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder, D: TypedEncoder](data: Vector[X4[A, B, C, D]]): Prop = { + def prop[ + A: TypedEncoder, + B: TypedEncoder, + C: TypedEncoder, + D: TypedEncoder + ](data: Vector[X4[A, B, C, D]] + ): Prop = { val dataset = TypedDataset.create(data) dataset.project[X4[A, B, C, D]].collect().run().toVector ?= data - dataset.project[X3[A, B, C]].collect().run().toVector ?= data.map(x => X3(x.a, x.b, x.c)) - dataset.project[X2[A, B]].collect().run().toVector ?= data.map(x => X2(x.a, x.b)) + dataset.project[X3[A, B, C]].collect().run().toVector ?= data.map(x => + X3(x.a, x.b, x.c) + ) + dataset.project[X2[A, B]].collect().run().toVector ?= data.map(x => + X2(x.a, x.b) + ) 
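      // SmartProject drives each of these checks: it proves at compile time that every field
      // of the target case class (X3, X2, and X1 below) exists in X4 with the same name and a
      // compatible encoder, so a projection is a typed column selection rather than a runtime
      // cast; mismatched shapes such as InvalidFooProjectionType/InvalidFooProjectionName above
      // are rejected by the compiler (hence the shapeless illTyped import in this suite).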
dataset.project[X1[A]].collect().run().toVector ?= data.map(x => X1(x.a)) } check(forAll(prop[Int, String, X1[String], Boolean] _)) check(forAll(prop[Short, Long, String, Boolean] _)) check(forAll(prop[Short, (Boolean, Boolean), String, (Int, Int)] _)) - check(forAll(prop[X2[String, Boolean], (Boolean, Boolean), String, Boolean] _)) - check(forAll(prop[X2[String, Boolean], X3[Boolean, Boolean, Long], String, String] _)) + check( + forAll(prop[X2[String, Boolean], (Boolean, Boolean), String, Boolean] _) + ) + check( + forAll( + prop[X2[String, Boolean], X3[Boolean, Boolean, Long], String, String] _ + ) + ) } test("X3U to X1,X2,X3 projections") { - def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder](data: Vector[X3U[A, B, C]]): Prop = { + def prop[A: TypedEncoder, B: TypedEncoder, C: TypedEncoder]( + data: Vector[X3U[A, B, C]] + ): Prop = { val dataset = TypedDataset.create(data) - dataset.project[X3[A, B, C]].collect().run().toVector ?= data.map(x => X3(x.a, x.b, x.c)) - dataset.project[X2[A, B]].collect().run().toVector ?= data.map(x => X2(x.a, x.b)) + dataset.project[X3[A, B, C]].collect().run().toVector ?= data.map(x => + X3(x.a, x.b, x.c) + ) + dataset.project[X2[A, B]].collect().run().toVector ?= data.map(x => + X2(x.a, x.b) + ) dataset.project[X1[A]].collect().run().toVector ?= data.map(x => X1(x.a)) } @@ -54,6 +77,8 @@ class SmartProjectTest extends TypedDatasetSuite { check(forAll(prop[Short, Long, String] _)) check(forAll(prop[Short, (Boolean, Boolean), String] _)) check(forAll(prop[X2[String, Boolean], (Boolean, Boolean), String] _)) - check(forAll(prop[X2[String, Boolean], X3[Boolean, Boolean, Long], String] _)) + check( + forAll(prop[X2[String, Boolean], X3[Boolean, Boolean, Long], String] _) + ) } } diff --git a/dataset/src/test/scala/frameless/ops/deserialized/FilterTests.scala b/dataset/src/test/scala/frameless/ops/deserialized/FilterTests.scala index b53000f09..28776ddb2 100644 --- a/dataset/src/test/scala/frameless/ops/deserialized/FilterTests.scala +++ b/dataset/src/test/scala/frameless/ops/deserialized/FilterTests.scala @@ -7,11 +7,17 @@ import org.scalacheck.Prop._ class FilterTests extends TypedDatasetSuite { test("filter") { - def prop[A: TypedEncoder](filterFunction: A => Boolean, data: Vector[A]): Prop = - TypedDataset.create(data). - deserialized. - filter(filterFunction). - collect().run().toVector =? data.filter(filterFunction) + def prop[A: TypedEncoder]( + filterFunction: A => Boolean, + data: Vector[A] + ): Prop = + TypedDataset + .create(data) + .deserialized + .filter(filterFunction) + .collect() + .run() + .toVector =? data.filter(filterFunction) check(forAll(prop[Int] _)) check(forAll(prop[String] _)) diff --git a/dataset/src/test/scala/frameless/ops/deserialized/FlatMapTests.scala b/dataset/src/test/scala/frameless/ops/deserialized/FlatMapTests.scala index 7dcd0e4e3..409d631e1 100644 --- a/dataset/src/test/scala/frameless/ops/deserialized/FlatMapTests.scala +++ b/dataset/src/test/scala/frameless/ops/deserialized/FlatMapTests.scala @@ -7,11 +7,17 @@ import org.scalacheck.Prop._ class FlatMapTests extends TypedDatasetSuite { test("flatMap") { - def prop[A: TypedEncoder, B: TypedEncoder](flatMapFunction: A => Vector[B], data: Vector[A]): Prop = - TypedDataset.create(data). - deserialized. - flatMap(flatMapFunction). - collect().run().toVector =? 
data.flatMap(flatMapFunction) + def prop[A: TypedEncoder, B: TypedEncoder]( + flatMapFunction: A => Vector[B], + data: Vector[A] + ): Prop = + TypedDataset + .create(data) + .deserialized + .flatMap(flatMapFunction) + .collect() + .run() + .toVector =? data.flatMap(flatMapFunction) check(forAll(prop[Int, Int] _)) check(forAll(prop[Int, String] _)) diff --git a/dataset/src/test/scala/frameless/ops/deserialized/MapPartitionsTests.scala b/dataset/src/test/scala/frameless/ops/deserialized/MapPartitionsTests.scala index 06ba04943..def7c9614 100644 --- a/dataset/src/test/scala/frameless/ops/deserialized/MapPartitionsTests.scala +++ b/dataset/src/test/scala/frameless/ops/deserialized/MapPartitionsTests.scala @@ -7,12 +7,18 @@ import org.scalacheck.Prop._ class MapPartitionsTests extends TypedDatasetSuite { test("mapPartitions") { - def prop[A: TypedEncoder, B: TypedEncoder](mapFunction: A => B, data: Vector[A]): Prop = { + def prop[A: TypedEncoder, B: TypedEncoder]( + mapFunction: A => B, + data: Vector[A] + ): Prop = { val lifted: Iterator[A] => Iterator[B] = _.map(mapFunction) - TypedDataset.create(data). - deserialized. - mapPartitions(lifted). - collect().run().toVector =? data.map(mapFunction) + TypedDataset + .create(data) + .deserialized + .mapPartitions(lifted) + .collect() + .run() + .toVector =? data.map(mapFunction) } check(forAll(prop[Int, Int] _)) diff --git a/dataset/src/test/scala/frameless/ops/deserialized/MapTests.scala b/dataset/src/test/scala/frameless/ops/deserialized/MapTests.scala index f7cc0fad0..3cad10439 100644 --- a/dataset/src/test/scala/frameless/ops/deserialized/MapTests.scala +++ b/dataset/src/test/scala/frameless/ops/deserialized/MapTests.scala @@ -7,11 +7,17 @@ import org.scalacheck.Prop._ class MapTests extends TypedDatasetSuite { test("map") { - def prop[A: TypedEncoder, B: TypedEncoder](mapFunction: A => B, data: Vector[A]): Prop = - TypedDataset.create(data). - deserialized. - map(mapFunction). - collect().run().toVector =? data.map(mapFunction) + def prop[A: TypedEncoder, B: TypedEncoder]( + mapFunction: A => B, + data: Vector[A] + ): Prop = + TypedDataset + .create(data) + .deserialized + .map(mapFunction) + .collect() + .run() + .toVector =? data.map(mapFunction) check(forAll(prop[Int, Int] _)) check(forAll(prop[Int, String] _)) diff --git a/dataset/src/test/scala/frameless/ops/deserialized/ReduceTests.scala b/dataset/src/test/scala/frameless/ops/deserialized/ReduceTests.scala index 01a074950..934475bb5 100644 --- a/dataset/src/test/scala/frameless/ops/deserialized/ReduceTests.scala +++ b/dataset/src/test/scala/frameless/ops/deserialized/ReduceTests.scala @@ -6,10 +6,16 @@ import org.scalacheck.Prop import org.scalacheck.Prop._ class ReduceTests extends TypedDatasetSuite { - def prop[A: TypedEncoder](reduceFunction: (A, A) => A)(data: Vector[A]): Prop = - TypedDataset.create(data). - deserialized. - reduceOption(reduceFunction).run() =? data.reduceOption(reduceFunction) + + def prop[A: TypedEncoder]( + reduceFunction: (A, A) => A + )(data: Vector[A] + ): Prop = + TypedDataset + .create(data) + .deserialized + .reduceOption(reduceFunction) + .run() =? 
data.reduceOption(reduceFunction) test("reduce Int") { check(forAll(prop[Int](_ + _) _)) diff --git a/dataset/src/test/scala/frameless/package.scala b/dataset/src/test/scala/frameless/package.scala index 82ff375c9..54972786e 100644 --- a/dataset/src/test/scala/frameless/package.scala +++ b/dataset/src/test/scala/frameless/package.scala @@ -1,9 +1,10 @@ import java.time.format.DateTimeFormatter -import java.time.{LocalDateTime => JavaLocalDateTime} +import java.time.{ LocalDateTime => JavaLocalDateTime } -import org.scalacheck.{Arbitrary, Gen} +import org.scalacheck.{ Arbitrary, Gen } package object frameless { + /** Fixed decimal point to avoid precision problems specific to Spark */ implicit val arbBigDecimal: Arbitrary[BigDecimal] = Arbitrary { for { @@ -30,7 +31,10 @@ package object frameless { } // see issue with scalacheck non serializable Vector: https://github.com/rickynils/scalacheck/issues/315 - implicit def arbVector[A](implicit A: Arbitrary[A]): Arbitrary[Vector[A]] = + implicit def arbVector[A]( + implicit + A: Arbitrary[A] + ): Arbitrary[Vector[A]] = Arbitrary(Gen.listOf(A.arbitrary).map(_.toVector)) def vectorGen[A: Arbitrary]: Gen[Vector[A]] = arbVector[A].arbitrary @@ -42,7 +46,8 @@ package object frameless { } yield new UdtEncodedClass(int, doubles.toArray) } - val dateTimeFormatter: DateTimeFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm") + val dateTimeFormatter: DateTimeFormatter = + DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm") implicit val localDateArb: Arbitrary[JavaLocalDateTime] = Arbitrary { for { @@ -72,11 +77,10 @@ package object frameless { def anyCauseHas(t: Throwable, f: Throwable => Boolean): Boolean = if (f(t)) true + else if (t.getCause ne null) + anyCauseHas(t.getCause, f) else - if (t.getCause ne null) - anyCauseHas(t.getCause, f) - else - false + false /** * Runs up to maxRuns and outputs the number of failures (times thrown) @@ -85,11 +89,11 @@ package object frameless { * @tparam T * @return the last passing thunk, or null */ - def runLoads[T](maxRuns: Int = 1000)(thunk: => T): T ={ + def runLoads[T](maxRuns: Int = 1000)(thunk: => T): T = { var i = 0 var r = null.asInstanceOf[T] var passed = 0 - while(i < maxRuns){ + while (i < maxRuns) { i += 1 try { r = thunk @@ -98,29 +102,36 @@ package object frameless { println(s"run $i successful") } } catch { - case t: Throwable => System.err.println(s"failed unexpectedly on run $i - ${t.getMessage}") + case t: Throwable => + System.err.println(s"failed unexpectedly on run $i - ${t.getMessage}") } } if (passed != maxRuns) { - System.err.println(s"had ${maxRuns - passed} failures out of $maxRuns runs") + System.err.println( + s"had ${maxRuns - passed} failures out of $maxRuns runs" + ) } r } - /** + /** * Runs a given thunk up to maxRuns times, restarting the thunk if tolerantOf the thrown Throwable is true * @param tolerantOf * @param maxRuns default of 20 * @param thunk * @return either a successful run result or the last error will be thrown */ - def tolerantRun[T](tolerantOf: Throwable => Boolean, maxRuns: Int = 20)(thunk: => T): T ={ + def tolerantRun[T]( + tolerantOf: Throwable => Boolean, + maxRuns: Int = 20 + )(thunk: => T + ): T = { var passed = false var i = 0 var res: T = null.asInstanceOf[T] var thrown: Throwable = null - while((i < maxRuns) && !passed) { + while ((i < maxRuns) && !passed) { try { i += 1 res = thunk diff --git a/dataset/src/test/scala/frameless/sql/package.scala b/dataset/src/test/scala/frameless/sql/package.scala index fcb45b03d..936d78042 100644 --- 
a/dataset/src/test/scala/frameless/sql/package.scala +++ b/dataset/src/test/scala/frameless/sql/package.scala @@ -1,16 +1,18 @@ package frameless import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.expressions.{And, Or} +import org.apache.spark.sql.catalyst.expressions.{ And, Or } package object sql { + implicit class ExpressionOps(val self: Expression) extends AnyVal { + def toList: List[Expression] = { def rec(expr: Expression, acc: List[Expression]): List[Expression] = { expr match { case And(left, right) => rec(left, rec(right, acc)) - case Or(left, right) => rec(left, rec(right, acc)) - case e => e +: acc + case Or(left, right) => rec(left, rec(right, acc)) + case e => e +: acc } } diff --git a/dataset/src/test/scala/frameless/sql/rules/SQLRulesSuite.scala b/dataset/src/test/scala/frameless/sql/rules/SQLRulesSuite.scala index 8555d1809..c4cce0b14 100644 --- a/dataset/src/test/scala/frameless/sql/rules/SQLRulesSuite.scala +++ b/dataset/src/test/scala/frameless/sql/rules/SQLRulesSuite.scala @@ -11,30 +11,36 @@ import org.scalatest.Assertion import org.scalatest.matchers.should.Matchers trait SQLRulesSuite extends TypedDatasetSuite with Matchers { self => + protected lazy val path: String = { val tmpDir = System.getProperty("java.io.tmpdir") s"$tmpDir/${self.getClass.getName}" } - def withDataset[A: TypedEncoder: CatalystOrdered](payload: A)(f: TypedDataset[A] => Assertion): Assertion = { + def withDataset[A: TypedEncoder: CatalystOrdered]( + payload: A + )(f: TypedDataset[A] => Assertion + ): Assertion = { TypedDataset.create(Seq(payload)).write.mode("overwrite").parquet(path) f(TypedDataset.createUnsafe[A](session.read.parquet(path))) } def predicatePushDownTest[A: TypedEncoder: CatalystOrdered]( - expected: X1[A], - expectedPushDownFilters: List[Filter], - planShouldNotContain: PartialFunction[Expression, Expression], - op: TypedColumn[X1[A], A] => TypedColumn[X1[A], Boolean] - ): Assertion = { + expected: X1[A], + expectedPushDownFilters: List[Filter], + planShouldNotContain: PartialFunction[Expression, Expression], + op: TypedColumn[X1[A], A] => TypedColumn[X1[A], Boolean] + ): Assertion = { withDataset(expected) { dataset => val ds = dataset.filter(op(dataset('a))) val actualPushDownFilters = pushDownFilters(ds) - val optimizedPlan = ds.queryExecution.optimizedPlan.collect { case logical.Filter(condition, _) => condition }.flatMap(_.toList) + val optimizedPlan = ds.queryExecution.optimizedPlan.collect { + case logical.Filter(condition, _) => condition + }.flatMap(_.toList) // check the optimized plan - optimizedPlan.collectFirst(planShouldNotContain) should be (empty) + optimizedPlan.collectFirst(planShouldNotContain) should be(empty) // compare filters actualPushDownFilters shouldBe expectedPushDownFilters @@ -53,18 +59,22 @@ trait SQLRulesSuite extends TypedDatasetSuite with Matchers { self => if (sparkPlan.children.isEmpty) // assume it's AQE sparkPlan match { case aq: AdaptiveSparkPlanExec => aq.initialPlan - case _ => sparkPlan + case _ => sparkPlan } else sparkPlan initialPlan.collect { case fs: FileSourceScanExec => - import scala.reflect.runtime.{universe => ru} + import scala.reflect.runtime.{ universe => ru } val runtimeMirror = ru.runtimeMirror(getClass.getClassLoader) val instanceMirror = runtimeMirror.reflect(fs) - val getter = ru.typeOf[FileSourceScanExec].member(ru.TermName("pushedDownFilters")).asTerm.getter + val getter = ru + .typeOf[FileSourceScanExec] + .member(ru.TermName("pushedDownFilters")) + .asTerm + .getter 
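        // pushedDownFilters is not publicly accessible on FileSourceScanExec, so the suite
        // resolves its getter via runtime reflection and invokes it below to recover the
        // Seq[Filter] that Spark actually pushed down to the data source.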
val m = instanceMirror.reflectMethod(getter.asMethod) val res = m.apply(fs).asInstanceOf[Seq[Filter]] diff --git a/dataset/src/test/scala/frameless/syntax/FramelessSyntaxTests.scala b/dataset/src/test/scala/frameless/syntax/FramelessSyntaxTests.scala index 5108ed581..51839e0a3 100644 --- a/dataset/src/test/scala/frameless/syntax/FramelessSyntaxTests.scala +++ b/dataset/src/test/scala/frameless/syntax/FramelessSyntaxTests.scala @@ -9,26 +9,37 @@ class FramelessSyntaxTests extends TypedDatasetSuite { // Hide the implicit SparkDelay[Job] on TypedDatasetSuite to avoid ambiguous implicits override val sparkDelay = null - def prop[A, B](data: Vector[X2[A, B]])( - implicit ev: TypedEncoder[X2[A, B]] - ): Prop = { + def prop[A, B]( + data: Vector[X2[A, B]] + )(implicit + ev: TypedEncoder[X2[A, B]] + ): Prop = { val dataset = TypedDataset.create(data).dataset val dataframe = dataset.toDF() val typedDataset = dataset.typed val typedDatasetFromDataFrame = dataframe.unsafeTyped[X2[A, B]] - typedDataset.collect().run().toVector ?= typedDatasetFromDataFrame.collect().run().toVector + typedDataset.collect().run().toVector ?= typedDatasetFromDataFrame + .collect() + .run() + .toVector } test("dataset typed - toTyped") { - def prop[A, B](data: Vector[X2[A, B]])( - implicit ev: TypedEncoder[X2[A, B]] - ): Prop = { - val dataset = session.createDataset(data)(TypedExpressionEncoder(ev)).typed + def prop[A, B]( + data: Vector[X2[A, B]] + )(implicit + ev: TypedEncoder[X2[A, B]] + ): Prop = { + val dataset = + session.createDataset(data)(TypedExpressionEncoder(ev)).typed val dataframe = dataset.toDF() - dataset.collect().run().toVector ?= dataframe.unsafeTyped[X2[A, B]].collect().run().toVector + dataset + .collect() + .run() + .toVector ?= dataframe.unsafeTyped[X2[A, B]].collect().run().toVector } check(forAll(prop[Int, String] _)) @@ -38,8 +49,13 @@ class FramelessSyntaxTests extends TypedDatasetSuite { test("frameless typed column and aggregate") { def prop[A: TypedEncoder](a: A, b: A): Prop = { val d = TypedDataset.create((a, b) :: Nil) - (d.select(d('_1).untyped.typedColumn).collect().run ?= d.select(d('_1)).collect().run).&&( - d.agg(first(d('_1))).collect().run() ?= d.agg(first(d('_1)).untyped.typedAggregate).collect().run() + (d.select(d('_1).untyped.typedColumn) + .collect() + .run ?= d.select(d('_1)).collect().run).&&( + d.agg(first(d('_1))).collect().run() ?= d + .agg(first(d('_1)).untyped.typedAggregate) + .collect() + .run() ) } diff --git a/dataset/src/test/scala/org/apache/hadoop/fs/local/StreamingFS.scala b/dataset/src/test/scala/org/apache/hadoop/fs/local/StreamingFS.scala index a28ad0820..c6e4b8c29 100644 --- a/dataset/src/test/scala/org/apache/hadoop/fs/local/StreamingFS.scala +++ b/dataset/src/test/scala/org/apache/hadoop/fs/local/StreamingFS.scala @@ -3,5 +3,11 @@ package org.apache.hadoop.fs.local import com.globalmentor.apache.hadoop.fs.BareLocalFileSystem import org.apache.hadoop.fs.DelegateToFileSystem -class StreamingFS(uri: java.net.URI, conf: org.apache.hadoop.conf.Configuration) extends - DelegateToFileSystem(uri, new BareLocalFileSystem(), conf, "file", false) {} +class StreamingFS(uri: java.net.URI, conf: org.apache.hadoop.conf.Configuration) + extends DelegateToFileSystem( + uri, + new BareLocalFileSystem(), + conf, + "file", + false + ) {} diff --git a/dataset/src/test/spark-3.2/frameless/sql/rules/FramelessLitPushDownTests.scala b/dataset/src/test/spark-3.2/frameless/sql/rules/FramelessLitPushDownTests.scala index c44ac4d08..6224dc78f 100644 --- 
a/dataset/src/test/spark-3.2/frameless/sql/rules/FramelessLitPushDownTests.scala +++ b/dataset/src/test/spark-3.2/frameless/sql/rules/FramelessLitPushDownTests.scala @@ -3,10 +3,17 @@ package frameless.sql.rules import frameless._ import frameless.sql._ import frameless.functions.Lit -import org.apache.spark.sql.catalyst.util.DateTimeUtils.{currentTimestamp, microsToInstant} -import org.apache.spark.sql.sources.{Filter, IsNotNull} +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{ + currentTimestamp, + microsToInstant +} +import org.apache.spark.sql.sources.{ Filter, IsNotNull } import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericRowWithSchema} +import org.apache.spark.sql.catalyst.expressions.{ + Cast, + Expression, + GenericRowWithSchema +} import java.time.Instant import org.apache.spark.sql.catalyst.plans.logical @@ -45,7 +52,10 @@ class FramelessLitPushDownTests extends SQLRulesSuite { test("struct push-down") { type Payload = X4[Int, Int, Int, Int] val expectedStructure = X1(X4(1, 2, 3, 4)) - val expected = new GenericRowWithSchema(Array(1, 2, 3, 4), TypedExpressionEncoder[Payload].schema) + val expected = new GenericRowWithSchema( + Array(1, 2, 3, 4), + TypedExpressionEncoder[Payload].schema + ) val expectedPushDownFilters = List(IsNotNull("a")) predicatePushDownTest[Payload]( @@ -58,16 +68,18 @@ class FramelessLitPushDownTests extends SQLRulesSuite { } override def predicatePushDownTest[A: TypedEncoder: CatalystOrdered]( - expected: X1[A], - expectedPushDownFilters: List[Filter], - planShouldContain: PartialFunction[Expression, Expression], - op: TypedColumn[X1[A], A] => TypedColumn[X1[A], Boolean] - ): Assertion = { + expected: X1[A], + expectedPushDownFilters: List[Filter], + planShouldContain: PartialFunction[Expression, Expression], + op: TypedColumn[X1[A], A] => TypedColumn[X1[A], Boolean] + ): Assertion = { withDataset(expected) { dataset => val ds = dataset.filter(op(dataset('a))) val actualPushDownFilters = pushDownFilters(ds) - val optimizedPlan = ds.queryExecution.optimizedPlan.collect { case logical.Filter(condition, _) => condition }.flatMap(_.toList) + val optimizedPlan = ds.queryExecution.optimizedPlan.collect { + case logical.Filter(condition, _) => condition + }.flatMap(_.toList) // check the optimized plan optimizedPlan.collectFirst(planShouldContain) should not be (empty) diff --git a/dataset/src/test/spark-3.3+/frameless/sql/rules/FramelessLitPushDownTests.scala b/dataset/src/test/spark-3.3+/frameless/sql/rules/FramelessLitPushDownTests.scala index 36a443fb5..93e723c2e 100644 --- a/dataset/src/test/spark-3.3+/frameless/sql/rules/FramelessLitPushDownTests.scala +++ b/dataset/src/test/spark-3.3+/frameless/sql/rules/FramelessLitPushDownTests.scala @@ -2,8 +2,11 @@ package frameless.sql.rules import frameless._ import frameless.functions.Lit -import org.apache.spark.sql.catalyst.util.DateTimeUtils.{currentTimestamp, microsToInstant} -import org.apache.spark.sql.sources.{EqualTo, GreaterThanOrEqual, IsNotNull} +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{ + currentTimestamp, + microsToInstant +} +import org.apache.spark.sql.sources.{ EqualTo, GreaterThanOrEqual, IsNotNull } import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import java.time.Instant @@ -14,7 +17,8 @@ class FramelessLitPushDownTests extends SQLRulesSuite { test("java.sql.Timestamp push-down") { val expected = 
java.sql.Timestamp.from(microsToInstant(now)) val expectedStructure = X1(SQLTimestamp(now)) - val expectedPushDownFilters = List(IsNotNull("a"), GreaterThanOrEqual("a", expected)) + val expectedPushDownFilters = + List(IsNotNull("a"), GreaterThanOrEqual("a", expected)) predicatePushDownTest[SQLTimestamp]( expectedStructure, @@ -27,7 +31,8 @@ class FramelessLitPushDownTests extends SQLRulesSuite { test("java.time.Instant push-down") { val expected = java.sql.Timestamp.from(microsToInstant(now)) val expectedStructure = X1(microsToInstant(now)) - val expectedPushDownFilters = List(IsNotNull("a"), GreaterThanOrEqual("a", expected)) + val expectedPushDownFilters = + List(IsNotNull("a"), GreaterThanOrEqual("a", expected)) predicatePushDownTest[Instant]( expectedStructure, @@ -40,7 +45,10 @@ class FramelessLitPushDownTests extends SQLRulesSuite { test("struct push-down") { type Payload = X4[Int, Int, Int, Int] val expectedStructure = X1(X4(1, 2, 3, 4)) - val expected = new GenericRowWithSchema(Array(1, 2, 3, 4), TypedExpressionEncoder[Payload].schema) + val expected = new GenericRowWithSchema( + Array(1, 2, 3, 4), + TypedExpressionEncoder[Payload].schema + ) val expectedPushDownFilters = List(IsNotNull("a"), EqualTo("a", expected)) predicatePushDownTest[Payload]( diff --git a/ml/src/main/scala/frameless/ml/TypedEstimator.scala b/ml/src/main/scala/frameless/ml/TypedEstimator.scala index 2628d234a..6c6ca6029 100644 --- a/ml/src/main/scala/frameless/ml/TypedEstimator.scala +++ b/ml/src/main/scala/frameless/ml/TypedEstimator.scala @@ -2,19 +2,20 @@ package frameless package ml import frameless.ops.SmartProject -import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.{ Estimator, Model } /** - * A TypedEstimator fits models to data. - */ + * A TypedEstimator fits models to data. + */ trait TypedEstimator[Inputs, Outputs, M <: Model[M]] { val estimator: Estimator[M] - def fit[T, F[_]](ds: TypedDataset[T])( - implicit - smartProject: SmartProject[T, Inputs], - F: SparkDelay[F] - ): F[AppendTransformer[Inputs, Outputs, M]] = { + def fit[T, F[_]]( + ds: TypedDataset[T] + )(implicit + smartProject: SmartProject[T, Inputs], + F: SparkDelay[F] + ): F[AppendTransformer[Inputs, Outputs, M]] = { implicit val sparkSession = ds.dataset.sparkSession F.delay { val inputDs = smartProject.apply(ds) diff --git a/ml/src/main/scala/frameless/ml/TypedTransformer.scala b/ml/src/main/scala/frameless/ml/TypedTransformer.scala index 99edc70e6..68dbd836c 100644 --- a/ml/src/main/scala/frameless/ml/TypedTransformer.scala +++ b/ml/src/main/scala/frameless/ml/TypedTransformer.scala @@ -3,31 +3,34 @@ package ml import frameless.ops.SmartProject import org.apache.spark.ml.Transformer -import shapeless.{Generic, HList} -import shapeless.ops.hlist.{Prepend, Tupler} +import shapeless.{ Generic, HList } +import shapeless.ops.hlist.{ Prepend, Tupler } /** - * A TypedTransformer transforms one TypedDataset into another. - */ + * A TypedTransformer transforms one TypedDataset into another. + */ sealed trait TypedTransformer /** - * An AppendTransformer `transform` method takes as input a TypedDataset containing `Inputs` and - * return a TypedDataset with `Outputs` columns appended to the input TypedDataset. - */ -trait AppendTransformer[Inputs, Outputs, InnerTransformer <: Transformer] extends TypedTransformer { + * An AppendTransformer `transform` method takes as input a TypedDataset containing `Inputs` and + * return a TypedDataset with `Outputs` columns appended to the input TypedDataset. 
+ */ +trait AppendTransformer[Inputs, Outputs, InnerTransformer <: Transformer] + extends TypedTransformer { val transformer: InnerTransformer - def transform[T, TVals <: HList, OutputsVals <: HList, OutVals <: HList, Out](ds: TypedDataset[T])( - implicit - i0: SmartProject[T, Inputs], - i1: Generic.Aux[T, TVals], - i2: Generic.Aux[Outputs, OutputsVals], - i3: Prepend.Aux[TVals, OutputsVals, OutVals], - i4: Tupler.Aux[OutVals, Out], - i5: TypedEncoder[Out] - ): TypedDataset[Out] = { - val transformed = transformer.transform(ds.dataset).as[Out](TypedExpressionEncoder[Out]) + def transform[T, TVals <: HList, OutputsVals <: HList, OutVals <: HList, Out]( + ds: TypedDataset[T] + )(implicit + i0: SmartProject[T, Inputs], + i1: Generic.Aux[T, TVals], + i2: Generic.Aux[Outputs, OutputsVals], + i3: Prepend.Aux[TVals, OutputsVals, OutVals], + i4: Tupler.Aux[OutVals, Out], + i5: TypedEncoder[Out] + ): TypedDataset[Out] = { + val transformed = + transformer.transform(ds.dataset).as[Out](TypedExpressionEncoder[Out]) TypedDataset.create[Out](transformed) } diff --git a/ml/src/main/scala/frameless/ml/classification/TypedRandomForestClassifier.scala b/ml/src/main/scala/frameless/ml/classification/TypedRandomForestClassifier.scala index f6efcceaf..1bd757478 100644 --- a/ml/src/main/scala/frameless/ml/classification/TypedRandomForestClassifier.scala +++ b/ml/src/main/scala/frameless/ml/classification/TypedRandomForestClassifier.scala @@ -4,48 +4,82 @@ package classification import frameless.ml.internals.TreesInputsChecker import frameless.ml.params.trees.FeatureSubsetStrategy -import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier} +import org.apache.spark.ml.classification.{ + RandomForestClassificationModel, + RandomForestClassifier +} import org.apache.spark.ml.linalg.Vector /** - * Random Forest learning algorithm for - * classification. - * It supports both binary and multiclass labels, as well as both continuous and categorical - * features. - */ -final class TypedRandomForestClassifier[Inputs] private[ml]( - rf: RandomForestClassifier, - labelCol: String, - featuresCol: String -) extends TypedEstimator[Inputs, TypedRandomForestClassifier.Outputs, RandomForestClassificationModel] { + * Random Forest learning algorithm for + * classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. 
+ */ +final class TypedRandomForestClassifier[Inputs] private[ml] ( + rf: RandomForestClassifier, + labelCol: String, + featuresCol: String) + extends TypedEstimator[ + Inputs, + TypedRandomForestClassifier.Outputs, + RandomForestClassificationModel + ] { val estimator: RandomForestClassifier = - rf - .setLabelCol(labelCol) + rf.setLabelCol(labelCol) .setFeaturesCol(featuresCol) .setPredictionCol(AppendTransformer.tempColumnName) .setRawPredictionCol(AppendTransformer.tempColumnName2) .setProbabilityCol(AppendTransformer.tempColumnName3) - def setNumTrees(value: Int): TypedRandomForestClassifier[Inputs] = copy(rf.setNumTrees(value)) - def setMaxDepth(value: Int): TypedRandomForestClassifier[Inputs] = copy(rf.setMaxDepth(value)) - def setMinInfoGain(value: Double): TypedRandomForestClassifier[Inputs] = copy(rf.setMinInfoGain(value)) - def setMinInstancesPerNode(value: Int): TypedRandomForestClassifier[Inputs] = copy(rf.setMinInstancesPerNode(value)) - def setMaxMemoryInMB(value: Int): TypedRandomForestClassifier[Inputs] = copy(rf.setMaxMemoryInMB(value)) - def setSubsamplingRate(value: Double): TypedRandomForestClassifier[Inputs] = copy(rf.setSubsamplingRate(value)) - def setFeatureSubsetStrategy(value: FeatureSubsetStrategy): TypedRandomForestClassifier[Inputs] = + def setNumTrees(value: Int): TypedRandomForestClassifier[Inputs] = + copy(rf.setNumTrees(value)) + + def setMaxDepth(value: Int): TypedRandomForestClassifier[Inputs] = + copy(rf.setMaxDepth(value)) + + def setMinInfoGain(value: Double): TypedRandomForestClassifier[Inputs] = + copy(rf.setMinInfoGain(value)) + + def setMinInstancesPerNode(value: Int): TypedRandomForestClassifier[Inputs] = + copy(rf.setMinInstancesPerNode(value)) + + def setMaxMemoryInMB(value: Int): TypedRandomForestClassifier[Inputs] = + copy(rf.setMaxMemoryInMB(value)) + + def setSubsamplingRate(value: Double): TypedRandomForestClassifier[Inputs] = + copy(rf.setSubsamplingRate(value)) + + def setFeatureSubsetStrategy( + value: FeatureSubsetStrategy + ): TypedRandomForestClassifier[Inputs] = copy(rf.setFeatureSubsetStrategy(value.sparkValue)) - def setMaxBins(value: Int): TypedRandomForestClassifier[Inputs] = copy(rf.setMaxBins(value)) - private def copy(newRf: RandomForestClassifier): TypedRandomForestClassifier[Inputs] = + def setMaxBins(value: Int): TypedRandomForestClassifier[Inputs] = + copy(rf.setMaxBins(value)) + + private def copy( + newRf: RandomForestClassifier + ): TypedRandomForestClassifier[Inputs] = new TypedRandomForestClassifier[Inputs](newRf, labelCol, featuresCol) } object TypedRandomForestClassifier { - case class Outputs(rawPrediction: Vector, probability: Vector, prediction: Double) - def apply[Inputs](implicit inputsChecker: TreesInputsChecker[Inputs]): TypedRandomForestClassifier[Inputs] = { - new TypedRandomForestClassifier(new RandomForestClassifier(), inputsChecker.labelCol, inputsChecker.featuresCol) + case class Outputs( + rawPrediction: Vector, + probability: Vector, + prediction: Double) + + def apply[Inputs]( + implicit + inputsChecker: TreesInputsChecker[Inputs] + ): TypedRandomForestClassifier[Inputs] = { + new TypedRandomForestClassifier( + new RandomForestClassifier(), + inputsChecker.labelCol, + inputsChecker.featuresCol + ) } } - diff --git a/ml/src/main/scala/frameless/ml/clustering/TypedBisectingKMeans.scala b/ml/src/main/scala/frameless/ml/clustering/TypedBisectingKMeans.scala index 4a8c974b4..210c9e589 100644 --- a/ml/src/main/scala/frameless/ml/clustering/TypedBisectingKMeans.scala +++ 
b/ml/src/main/scala/frameless/ml/clustering/TypedBisectingKMeans.scala @@ -3,39 +3,46 @@ package ml package classification import frameless.ml.internals.VectorInputsChecker -import org.apache.spark.ml.clustering.{BisectingKMeans, BisectingKMeansModel} +import org.apache.spark.ml.clustering.{ BisectingKMeans, BisectingKMeansModel } /** - * A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques" - * by Steinbach, Karypis, and Kumar, with modification to fit Spark. - * The algorithm starts from a single cluster that contains all points. - * Iteratively it finds divisible clusters on the bottom level and bisects each of them using - * k-means, until there are `k` leaf clusters in total or no leaf clusters are divisible. - * The bisecting steps of clusters on the same level are grouped together to increase parallelism. - * If bisecting all divisible clusters on the bottom level would result more than `k` leaf clusters, - * larger clusters get higher priority. - * - * @see - * Steinbach, Karypis, and Kumar, A comparison of document clustering techniques, - * KDD Workshop on Text Mining, 2000. - */ + * A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques" + * by Steinbach, Karypis, and Kumar, with modification to fit Spark. + * The algorithm starts from a single cluster that contains all points. + * Iteratively it finds divisible clusters on the bottom level and bisects each of them using + * k-means, until there are `k` leaf clusters in total or no leaf clusters are divisible. + * The bisecting steps of clusters on the same level are grouped together to increase parallelism. + * If bisecting all divisible clusters on the bottom level would result more than `k` leaf clusters, + * larger clusters get higher priority. + * + * @see + * Steinbach, Karypis, and Kumar, A comparison of document clustering techniques, + * KDD Workshop on Text Mining, 2000. 
+ */ class TypedBisectingKMeans[Inputs] private[ml] ( - bkm: BisectingKMeans, - featuresCol: String -) extends TypedEstimator[Inputs,TypedBisectingKMeans.Output, BisectingKMeansModel]{ + bkm: BisectingKMeans, + featuresCol: String) + extends TypedEstimator[ + Inputs, + TypedBisectingKMeans.Output, + BisectingKMeansModel + ] { + val estimator: BisectingKMeans = bkm - .setFeaturesCol(featuresCol) - .setPredictionCol(AppendTransformer.tempColumnName) - + .setFeaturesCol(featuresCol) + .setPredictionCol(AppendTransformer.tempColumnName) + def setK(value: Int): TypedBisectingKMeans[Inputs] = copy(bkm.setK(value)) - - def setMaxIter(value: Int): TypedBisectingKMeans[Inputs] = copy(bkm.setMaxIter(value)) + + def setMaxIter(value: Int): TypedBisectingKMeans[Inputs] = + copy(bkm.setMaxIter(value)) def setMinDivisibleClusterSize(value: Double): TypedBisectingKMeans[Inputs] = copy(bkm.setMinDivisibleClusterSize(value)) - - def setSeed(value: Long): TypedBisectingKMeans[Inputs] = copy(bkm.setSeed(value)) + + def setSeed(value: Long): TypedBisectingKMeans[Inputs] = + copy(bkm.setSeed(value)) private def copy(newBkm: BisectingKMeans): TypedBisectingKMeans[Inputs] = new TypedBisectingKMeans[Inputs](newBkm, featuresCol) @@ -44,6 +51,9 @@ class TypedBisectingKMeans[Inputs] private[ml] ( object TypedBisectingKMeans { case class Output(prediction: Int) - def apply[Inputs]()(implicit inputsChecker: VectorInputsChecker[Inputs]): TypedBisectingKMeans[Inputs] = + def apply[Inputs]( + )(implicit + inputsChecker: VectorInputsChecker[Inputs] + ): TypedBisectingKMeans[Inputs] = new TypedBisectingKMeans(new BisectingKMeans(), inputsChecker.featuresCol) -} \ No newline at end of file +} diff --git a/ml/src/main/scala/frameless/ml/clustering/TypedKMeans.scala b/ml/src/main/scala/frameless/ml/clustering/TypedKMeans.scala index 1a32076a5..ed082a20b 100644 --- a/ml/src/main/scala/frameless/ml/clustering/TypedKMeans.scala +++ b/ml/src/main/scala/frameless/ml/clustering/TypedKMeans.scala @@ -4,27 +4,29 @@ package classification import frameless.ml.internals.VectorInputsChecker import frameless.ml.params.kmeans.KMeansInitMode -import org.apache.spark.ml.clustering.{KMeans, KMeansModel} +import org.apache.spark.ml.clustering.{ KMeans, KMeansModel } /** - * K-means clustering with support for k-means|| initialization proposed by Bahmani et al. - * - * @see Bahmani et al., Scalable k-means++. - */ + * K-means clustering with support for k-means|| initialization proposed by Bahmani et al. + * + * @see Bahmani et al., Scalable k-means++. 
+ */ class TypedKMeans[Inputs] private[ml] ( - km: KMeans, - featuresCol: String -) extends TypedEstimator[Inputs,TypedKMeans.Output,KMeansModel] { + km: KMeans, + featuresCol: String) + extends TypedEstimator[Inputs, TypedKMeans.Output, KMeansModel] { + val estimator: KMeans = - km - .setFeaturesCol(featuresCol) + km.setFeaturesCol(featuresCol) .setPredictionCol(AppendTransformer.tempColumnName) def setK(value: Int): TypedKMeans[Inputs] = copy(km.setK(value)) - def setInitMode(value: KMeansInitMode): TypedKMeans[Inputs] = copy(km.setInitMode(value.sparkValue)) + def setInitMode(value: KMeansInitMode): TypedKMeans[Inputs] = + copy(km.setInitMode(value.sparkValue)) - def setInitSteps(value: Int): TypedKMeans[Inputs] = copy(km.setInitSteps(value)) + def setInitSteps(value: Int): TypedKMeans[Inputs] = + copy(km.setInitSteps(value)) def setMaxIter(value: Int): TypedKMeans[Inputs] = copy(km.setMaxIter(value)) @@ -32,14 +34,18 @@ class TypedKMeans[Inputs] private[ml] ( def setSeed(value: Long): TypedKMeans[Inputs] = copy(km.setSeed(value)) - private def copy(newKmeans: KMeans): TypedKMeans[Inputs] = new TypedKMeans[Inputs](newKmeans, featuresCol) + private def copy(newKmeans: KMeans): TypedKMeans[Inputs] = + new TypedKMeans[Inputs](newKmeans, featuresCol) } -object TypedKMeans{ +object TypedKMeans { case class Output(prediction: Int) - def apply[Inputs](implicit inputsChecker: VectorInputsChecker[Inputs]): TypedKMeans[Inputs] = { + def apply[Inputs]( + implicit + inputsChecker: VectorInputsChecker[Inputs] + ): TypedKMeans[Inputs] = { new TypedKMeans(new KMeans(), inputsChecker.featuresCol) } } diff --git a/ml/src/main/scala/frameless/ml/feature/TypedIndexToString.scala b/ml/src/main/scala/frameless/ml/feature/TypedIndexToString.scala index af2e9684a..f650ed678 100644 --- a/ml/src/main/scala/frameless/ml/feature/TypedIndexToString.scala +++ b/ml/src/main/scala/frameless/ml/feature/TypedIndexToString.scala @@ -6,14 +6,20 @@ import frameless.ml.internals.UnaryInputsChecker import org.apache.spark.ml.feature.IndexToString /** - * A `TypedTransformer` that maps a column of indices back to a new column of corresponding - * string values. - * The index-string mapping must be supplied when creating the `TypedIndexToString`. - * - * @see `TypedStringIndexer` for converting strings into indices - */ -final class TypedIndexToString[Inputs] private[ml](indexToString: IndexToString, inputCol: String) - extends AppendTransformer[Inputs, TypedIndexToString.Outputs, IndexToString] { + * A `TypedTransformer` that maps a column of indices back to a new column of corresponding + * string values. + * The index-string mapping must be supplied when creating the `TypedIndexToString`. 
+ * + * @see `TypedStringIndexer` for converting strings into indices + */ +final class TypedIndexToString[Inputs] private[ml] ( + indexToString: IndexToString, + inputCol: String) + extends AppendTransformer[ + Inputs, + TypedIndexToString.Outputs, + IndexToString + ] { val transformer: IndexToString = indexToString @@ -25,8 +31,14 @@ final class TypedIndexToString[Inputs] private[ml](indexToString: IndexToString, object TypedIndexToString { case class Outputs(originalOutput: String) - def apply[Inputs](labels: Array[String]) - (implicit inputsChecker: UnaryInputsChecker[Inputs, Double]): TypedIndexToString[Inputs] = { - new TypedIndexToString[Inputs](new IndexToString().setLabels(labels), inputsChecker.inputCol) + def apply[Inputs]( + labels: Array[String] + )(implicit + inputsChecker: UnaryInputsChecker[Inputs, Double] + ): TypedIndexToString[Inputs] = { + new TypedIndexToString[Inputs]( + new IndexToString().setLabels(labels), + inputsChecker.inputCol + ) } -} \ No newline at end of file +} diff --git a/ml/src/main/scala/frameless/ml/feature/TypedStringIndexer.scala b/ml/src/main/scala/frameless/ml/feature/TypedStringIndexer.scala index 7eba8e306..d398e3581 100644 --- a/ml/src/main/scala/frameless/ml/feature/TypedStringIndexer.scala +++ b/ml/src/main/scala/frameless/ml/feature/TypedStringIndexer.scala @@ -4,25 +4,34 @@ package feature import frameless.ml.feature.TypedStringIndexer.HandleInvalid import frameless.ml.internals.UnaryInputsChecker -import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel} +import org.apache.spark.ml.feature.{ StringIndexer, StringIndexerModel } /** - * A label indexer that maps a string column of labels to an ML column of label indices. - * The indices are in [0, numLabels), ordered by label frequencies. - * So the most frequent label gets index 0. - * - * @see `TypedIndexToString` for the inverse transformation - */ -final class TypedStringIndexer[Inputs] private[ml](stringIndexer: StringIndexer, inputCol: String) - extends TypedEstimator[Inputs, TypedStringIndexer.Outputs, StringIndexerModel] { + * A label indexer that maps a string column of labels to an ML column of label indices. + * The indices are in [0, numLabels), ordered by label frequencies. + * So the most frequent label gets index 0. 
+ * + * @see `TypedIndexToString` for the inverse transformation + */ +final class TypedStringIndexer[Inputs] private[ml] ( + stringIndexer: StringIndexer, + inputCol: String) + extends TypedEstimator[ + Inputs, + TypedStringIndexer.Outputs, + StringIndexerModel + ] { val estimator: StringIndexer = stringIndexer .setInputCol(inputCol) .setOutputCol(AppendTransformer.tempColumnName) - def setHandleInvalid(value: HandleInvalid): TypedStringIndexer[Inputs] = copy(stringIndexer.setHandleInvalid(value.sparkValue)) + def setHandleInvalid(value: HandleInvalid): TypedStringIndexer[Inputs] = + copy(stringIndexer.setHandleInvalid(value.sparkValue)) - private def copy(newStringIndexer: StringIndexer): TypedStringIndexer[Inputs] = + private def copy( + newStringIndexer: StringIndexer + ): TypedStringIndexer[Inputs] = new TypedStringIndexer[Inputs](newStringIndexer, inputCol) } @@ -30,13 +39,17 @@ object TypedStringIndexer { case class Outputs(indexedOutput: Double) sealed abstract class HandleInvalid(val sparkValue: String) + object HandleInvalid { case object Error extends HandleInvalid("error") case object Skip extends HandleInvalid("skip") case object Keep extends HandleInvalid("keep") } - def apply[Inputs](implicit inputsChecker: UnaryInputsChecker[Inputs, String]): TypedStringIndexer[Inputs] = { + def apply[Inputs]( + implicit + inputsChecker: UnaryInputsChecker[Inputs, String] + ): TypedStringIndexer[Inputs] = { new TypedStringIndexer[Inputs](new StringIndexer(), inputsChecker.inputCol) } -} \ No newline at end of file +} diff --git a/ml/src/main/scala/frameless/ml/feature/TypedVectorAssembler.scala b/ml/src/main/scala/frameless/ml/feature/TypedVectorAssembler.scala index d599011b3..4008462b9 100644 --- a/ml/src/main/scala/frameless/ml/feature/TypedVectorAssembler.scala +++ b/ml/src/main/scala/frameless/ml/feature/TypedVectorAssembler.scala @@ -4,17 +4,23 @@ package feature import org.apache.spark.ml.feature.VectorAssembler import org.apache.spark.ml.linalg.Vector -import shapeless.{HList, HNil, LabelledGeneric} +import shapeless.{ HList, HNil, LabelledGeneric } import shapeless.ops.hlist.ToTraversable -import shapeless.ops.record.{Keys, Values} +import shapeless.ops.record.{ Keys, Values } import shapeless._ import scala.annotation.implicitNotFound /** - * A feature transformer that merges multiple columns into a vector column. - */ -final class TypedVectorAssembler[Inputs] private[ml](vectorAssembler: VectorAssembler, inputCols: Array[String]) - extends AppendTransformer[Inputs, TypedVectorAssembler.Output, VectorAssembler] { + * A feature transformer that merges multiple columns into a vector column. 
+ */ +final class TypedVectorAssembler[Inputs] private[ml] ( + vectorAssembler: VectorAssembler, + inputCols: Array[String]) + extends AppendTransformer[ + Inputs, + TypedVectorAssembler.Output, + VectorAssembler + ] { val transformer: VectorAssembler = vectorAssembler .setInputCols(inputCols) @@ -25,8 +31,14 @@ final class TypedVectorAssembler[Inputs] private[ml](vectorAssembler: VectorAsse object TypedVectorAssembler { case class Output(vector: Vector) - def apply[Inputs](implicit inputsChecker: TypedVectorAssemblerInputsChecker[Inputs]): TypedVectorAssembler[Inputs] = { - new TypedVectorAssembler(new VectorAssembler(), inputsChecker.inputCols.toArray) + def apply[Inputs]( + implicit + inputsChecker: TypedVectorAssemblerInputsChecker[Inputs] + ): TypedVectorAssembler[Inputs] = { + new TypedVectorAssembler( + new VectorAssembler(), + inputsChecker.inputCols.toArray + ) } } @@ -38,32 +50,41 @@ private[ml] trait TypedVectorAssemblerInputsChecker[Inputs] { } private[ml] object TypedVectorAssemblerInputsChecker { - implicit def checkInputs[Inputs, InputsRec <: HList, InputsKeys <: HList, InputsVals <: HList]( - implicit - inputsGen: LabelledGeneric.Aux[Inputs, InputsRec], - inputsKeys: Keys.Aux[InputsRec, InputsKeys], - inputsKeysTraverse: ToTraversable.Aux[InputsKeys, Seq, Symbol], - inputsValues: Values.Aux[InputsRec, InputsVals], - inputsTypeCheck: TypedVectorAssemblerInputsValueChecker[InputsVals] - ): TypedVectorAssemblerInputsChecker[Inputs] = new TypedVectorAssemblerInputsChecker[Inputs] { - val inputCols: Seq[String] = inputsKeys.apply().to[Seq].map(_.name) - } + + implicit def checkInputs[ + Inputs, + InputsRec <: HList, + InputsKeys <: HList, + InputsVals <: HList + ](implicit + inputsGen: LabelledGeneric.Aux[Inputs, InputsRec], + inputsKeys: Keys.Aux[InputsRec, InputsKeys], + inputsKeysTraverse: ToTraversable.Aux[InputsKeys, Seq, Symbol], + inputsValues: Values.Aux[InputsRec, InputsVals], + inputsTypeCheck: TypedVectorAssemblerInputsValueChecker[InputsVals] + ): TypedVectorAssemblerInputsChecker[Inputs] = + new TypedVectorAssemblerInputsChecker[Inputs] { + val inputCols: Seq[String] = inputsKeys.apply().to[Seq].map(_.name) + } } private[ml] trait TypedVectorAssemblerInputsValueChecker[InputsVals] private[ml] object TypedVectorAssemblerInputsValueChecker { + implicit def hnilCheckInputsValue: TypedVectorAssemblerInputsValueChecker[HNil] = new TypedVectorAssemblerInputsValueChecker[HNil] {} implicit def hlistCheckInputsValueNumeric[H, T <: HList]( - implicit ch: CatalystNumeric[H], - tt: TypedVectorAssemblerInputsValueChecker[T] - ): TypedVectorAssemblerInputsValueChecker[H :: T] = new TypedVectorAssemblerInputsValueChecker[H :: T] {} + implicit + ch: CatalystNumeric[H], + tt: TypedVectorAssemblerInputsValueChecker[T] + ): TypedVectorAssemblerInputsValueChecker[H :: T] = + new TypedVectorAssemblerInputsValueChecker[H :: T] {} implicit def hlistCheckInputsValueBoolean[T <: HList]( - implicit tt: TypedVectorAssemblerInputsValueChecker[T] - ): TypedVectorAssemblerInputsValueChecker[Boolean :: T] = new TypedVectorAssemblerInputsValueChecker[Boolean :: T] {} + implicit + tt: TypedVectorAssemblerInputsValueChecker[T] + ): TypedVectorAssemblerInputsValueChecker[Boolean :: T] = + new TypedVectorAssemblerInputsValueChecker[Boolean :: T] {} } - - diff --git a/ml/src/main/scala/frameless/ml/internals/LinearInputsChecker.scala b/ml/src/main/scala/frameless/ml/internals/LinearInputsChecker.scala index 995a3f961..edc645397 100644 --- 
a/ml/src/main/scala/frameless/ml/internals/LinearInputsChecker.scala +++ b/ml/src/main/scala/frameless/ml/internals/LinearInputsChecker.scala @@ -4,13 +4,13 @@ package internals import org.apache.spark.ml.linalg._ import shapeless.ops.hlist.Length -import shapeless.{HList, LabelledGeneric, Nat, Witness} +import shapeless.{ HList, LabelledGeneric, Nat, Witness } import scala.annotation.implicitNotFound /** - * Can be used for linear reg algorithm - */ + * Can be used for linear reg algorithm + */ @implicitNotFound( msg = "Cannot prove that ${Inputs} is a valid input type. " + "Input type must only contain a field of type Double (the label) and a field of type " + @@ -25,18 +25,18 @@ trait LinearInputsChecker[Inputs] { object LinearInputsChecker { implicit def checkLinearInputs[ - Inputs, - InputsRec <: HList, - LabelK <: Symbol, - FeaturesK <: Symbol]( - implicit - i0: LabelledGeneric.Aux[Inputs, InputsRec], - i1: Length.Aux[InputsRec, Nat._2], - i2: SelectorByValue.Aux[InputsRec, Double, LabelK], - i3: Witness.Aux[LabelK], - i4: SelectorByValue.Aux[InputsRec, Vector, FeaturesK], - i5: Witness.Aux[FeaturesK] - ): LinearInputsChecker[Inputs] = { + Inputs, + InputsRec <: HList, + LabelK <: Symbol, + FeaturesK <: Symbol + ](implicit + i0: LabelledGeneric.Aux[Inputs, InputsRec], + i1: Length.Aux[InputsRec, Nat._2], + i2: SelectorByValue.Aux[InputsRec, Double, LabelK], + i3: Witness.Aux[LabelK], + i4: SelectorByValue.Aux[InputsRec, Vector, FeaturesK], + i5: Witness.Aux[FeaturesK] + ): LinearInputsChecker[Inputs] = { new LinearInputsChecker[Inputs] { val labelCol: String = implicitly[Witness.Aux[LabelK]].value.name val featuresCol: String = implicitly[Witness.Aux[FeaturesK]].value.name @@ -45,25 +45,27 @@ object LinearInputsChecker { } implicit def checkLinearInputs2[ - Inputs, - InputsRec <: HList, - LabelK <: Symbol, - FeaturesK <: Symbol, - WeightK <: Symbol]( - implicit - i0: LabelledGeneric.Aux[Inputs, InputsRec], - i1: Length.Aux[InputsRec, Nat._3], - i2: SelectorByValue.Aux[InputsRec, Vector, FeaturesK], - i3: Witness.Aux[FeaturesK], - i4: SelectorByValue.Aux[InputsRec, Double, LabelK], - i5: Witness.Aux[LabelK], - i6: SelectorByValue.Aux[InputsRec, Float, WeightK], - i7: Witness.Aux[WeightK] - ): LinearInputsChecker[Inputs] = { + Inputs, + InputsRec <: HList, + LabelK <: Symbol, + FeaturesK <: Symbol, + WeightK <: Symbol + ](implicit + i0: LabelledGeneric.Aux[Inputs, InputsRec], + i1: Length.Aux[InputsRec, Nat._3], + i2: SelectorByValue.Aux[InputsRec, Vector, FeaturesK], + i3: Witness.Aux[FeaturesK], + i4: SelectorByValue.Aux[InputsRec, Double, LabelK], + i5: Witness.Aux[LabelK], + i6: SelectorByValue.Aux[InputsRec, Float, WeightK], + i7: Witness.Aux[WeightK] + ): LinearInputsChecker[Inputs] = { new LinearInputsChecker[Inputs] { val labelCol: String = implicitly[Witness.Aux[LabelK]].value.name val featuresCol: String = implicitly[Witness.Aux[FeaturesK]].value.name - val weightCol: Option[String] = Some(implicitly[Witness.Aux[WeightK]].value.name) + val weightCol: Option[String] = Some( + implicitly[Witness.Aux[WeightK]].value.name + ) } } diff --git a/ml/src/main/scala/frameless/ml/internals/SelectorByValue.scala b/ml/src/main/scala/frameless/ml/internals/SelectorByValue.scala index 9a67d5299..7292ff855 100644 --- a/ml/src/main/scala/frameless/ml/internals/SelectorByValue.scala +++ b/ml/src/main/scala/frameless/ml/internals/SelectorByValue.scala @@ -3,24 +3,35 @@ package ml package internals import shapeless.labelled.FieldType -import shapeless.{::, DepFn1, HList, Witness} +import 
shapeless.{ ::, DepFn1, HList, Witness } /** - * Typeclass supporting record selection by value type (returning the first key whose value is of type `Value`) - */ -trait SelectorByValue[L <: HList, Value] extends DepFn1[L] with Serializable { type Out <: Symbol } + * Typeclass supporting record selection by value type (returning the first key whose value is of type `Value`) + */ +trait SelectorByValue[L <: HList, Value] extends DepFn1[L] with Serializable { + type Out <: Symbol +} object SelectorByValue { - type Aux[L <: HList, Value, Out0 <: Symbol] = SelectorByValue[L, Value] { type Out = Out0 } - implicit def select[K <: Symbol, T <: HList, Value](implicit wk: Witness.Aux[K]): Aux[FieldType[K, Value] :: T, Value, K] = { + type Aux[L <: HList, Value, Out0 <: Symbol] = SelectorByValue[L, Value] { + type Out = Out0 + } + + implicit def select[K <: Symbol, T <: HList, Value]( + implicit + wk: Witness.Aux[K] + ): Aux[FieldType[K, Value] :: T, Value, K] = { new SelectorByValue[FieldType[K, Value] :: T, Value] { type Out = K def apply(l: FieldType[K, Value] :: T): Out = wk.value } } - implicit def recurse[H, T <: HList, Value](implicit st: SelectorByValue[T, Value]): Aux[H :: T, Value, st.Out] = { + implicit def recurse[H, T <: HList, Value]( + implicit + st: SelectorByValue[T, Value] + ): Aux[H :: T, Value, st.Out] = { new SelectorByValue[H :: T, Value] { type Out = st.Out def apply(l: H :: T): Out = st(l.tail) diff --git a/ml/src/main/scala/frameless/ml/internals/TreesInputsChecker.scala b/ml/src/main/scala/frameless/ml/internals/TreesInputsChecker.scala index 0fe157654..a7fd7f2c0 100644 --- a/ml/src/main/scala/frameless/ml/internals/TreesInputsChecker.scala +++ b/ml/src/main/scala/frameless/ml/internals/TreesInputsChecker.scala @@ -3,14 +3,14 @@ package ml package internals import shapeless.ops.hlist.Length -import shapeless.{HList, LabelledGeneric, Nat, Witness} +import shapeless.{ HList, LabelledGeneric, Nat, Witness } import org.apache.spark.ml.linalg._ import scala.annotation.implicitNotFound /** - * Can be used for all tree-based ML algorithm (decision tree, random forest, gradient-boosted trees) - */ + * Can be used for all tree-based ML algorithm (decision tree, random forest, gradient-boosted trees) + */ @implicitNotFound( msg = "Cannot prove that ${Inputs} is a valid input type. 
" + "Input type must only contain a field of type Double (the label) and a field of type " + @@ -24,18 +24,18 @@ trait TreesInputsChecker[Inputs] { object TreesInputsChecker { implicit def checkTreesInputs[ - Inputs, - InputsRec <: HList, - LabelK <: Symbol, - FeaturesK <: Symbol]( - implicit - i0: LabelledGeneric.Aux[Inputs, InputsRec], - i1: Length.Aux[InputsRec, Nat._2], - i2: SelectorByValue.Aux[InputsRec, Double, LabelK], - i3: Witness.Aux[LabelK], - i4: SelectorByValue.Aux[InputsRec, Vector, FeaturesK], - i5: Witness.Aux[FeaturesK] - ): TreesInputsChecker[Inputs] = { + Inputs, + InputsRec <: HList, + LabelK <: Symbol, + FeaturesK <: Symbol + ](implicit + i0: LabelledGeneric.Aux[Inputs, InputsRec], + i1: Length.Aux[InputsRec, Nat._2], + i2: SelectorByValue.Aux[InputsRec, Double, LabelK], + i3: Witness.Aux[LabelK], + i4: SelectorByValue.Aux[InputsRec, Vector, FeaturesK], + i5: Witness.Aux[FeaturesK] + ): TreesInputsChecker[Inputs] = { new TreesInputsChecker[Inputs] { val labelCol: String = implicitly[Witness.Aux[LabelK]].value.name val featuresCol: String = implicitly[Witness.Aux[FeaturesK]].value.name diff --git a/ml/src/main/scala/frameless/ml/internals/UnaryInputsChecker.scala b/ml/src/main/scala/frameless/ml/internals/UnaryInputsChecker.scala index 56dfc9a57..51b8e1306 100644 --- a/ml/src/main/scala/frameless/ml/internals/UnaryInputsChecker.scala +++ b/ml/src/main/scala/frameless/ml/internals/UnaryInputsChecker.scala @@ -3,13 +3,13 @@ package ml package internals import shapeless.ops.hlist.Length -import shapeless.{HList, LabelledGeneric, Nat, Witness} +import shapeless.{ HList, LabelledGeneric, Nat, Witness } import scala.annotation.implicitNotFound /** - * Can be used for all unary transformers (i.e almost all of them) - */ + * Can be used for all unary transformers (i.e almost all of them) + */ @implicitNotFound( msg = "Cannot prove that ${Inputs} is a valid input type. 
Input type must have only one field of type ${Expected}" ) @@ -19,15 +19,19 @@ trait UnaryInputsChecker[Inputs, Expected] { object UnaryInputsChecker { - implicit def checkUnaryInputs[Inputs, Expected, InputsRec <: HList, InputK <: Symbol]( - implicit - i0: LabelledGeneric.Aux[Inputs, InputsRec], - i1: Length.Aux[InputsRec, Nat._1], - i2: SelectorByValue.Aux[InputsRec, Expected, InputK], - i3: Witness.Aux[InputK] - ): UnaryInputsChecker[Inputs, Expected] = new UnaryInputsChecker[Inputs, Expected] { - val inputCol: String = implicitly[Witness.Aux[InputK]].value.name - } + implicit def checkUnaryInputs[ + Inputs, + Expected, + InputsRec <: HList, + InputK <: Symbol + ](implicit + i0: LabelledGeneric.Aux[Inputs, InputsRec], + i1: Length.Aux[InputsRec, Nat._1], + i2: SelectorByValue.Aux[InputsRec, Expected, InputK], + i3: Witness.Aux[InputK] + ): UnaryInputsChecker[Inputs, Expected] = + new UnaryInputsChecker[Inputs, Expected] { + val inputCol: String = implicitly[Witness.Aux[InputK]].value.name + } } - diff --git a/ml/src/main/scala/frameless/ml/internals/VectorInputsChecker.scala b/ml/src/main/scala/frameless/ml/internals/VectorInputsChecker.scala index e993d9a55..13f199de0 100644 --- a/ml/src/main/scala/frameless/ml/internals/VectorInputsChecker.scala +++ b/ml/src/main/scala/frameless/ml/internals/VectorInputsChecker.scala @@ -3,7 +3,7 @@ package ml package internals import shapeless.ops.hlist.Length -import shapeless.{HList, LabelledGeneric, Nat, Witness} +import shapeless.{ HList, LabelledGeneric, Nat, Witness } import scala.annotation.implicitNotFound import org.apache.spark.ml.linalg.Vector @@ -18,15 +18,19 @@ trait VectorInputsChecker[Inputs] { } object VectorInputsChecker { - implicit def checkVectorInput[Inputs, InputsRec <: HList, FeaturesK <: Symbol]( - implicit + + implicit def checkVectorInput[ + Inputs, + InputsRec <: HList, + FeaturesK <: Symbol + ](implicit i0: LabelledGeneric.Aux[Inputs, InputsRec], i1: Length.Aux[InputsRec, Nat._1], i2: SelectorByValue.Aux[InputsRec, Vector, FeaturesK], i3: Witness.Aux[FeaturesK] ): VectorInputsChecker[Inputs] = { - new VectorInputsChecker[Inputs] { - val featuresCol: String = i3.value.name - } + new VectorInputsChecker[Inputs] { + val featuresCol: String = i3.value.name } + } } diff --git a/ml/src/main/scala/frameless/ml/package.scala b/ml/src/main/scala/frameless/ml/package.scala index d1c306158..8478c1332 100644 --- a/ml/src/main/scala/frameless/ml/package.scala +++ b/ml/src/main/scala/frameless/ml/package.scala @@ -2,12 +2,14 @@ package frameless import org.apache.spark.sql.FramelessInternals.UserDefinedType import org.apache.spark.ml.FramelessInternals -import org.apache.spark.ml.linalg.{Matrix, Vector} +import org.apache.spark.ml.linalg.{ Matrix, Vector } package object ml { - implicit val mlVectorUdt: UserDefinedType[Vector] = FramelessInternals.vectorUdt + implicit val mlVectorUdt: UserDefinedType[Vector] = + FramelessInternals.vectorUdt - implicit val mlMatrixUdt: UserDefinedType[Matrix] = FramelessInternals.matrixUdt + implicit val mlMatrixUdt: UserDefinedType[Matrix] = + FramelessInternals.matrixUdt } diff --git a/ml/src/main/scala/frameless/ml/params/kmeans/KMeansInitMode.scala b/ml/src/main/scala/frameless/ml/params/kmeans/KMeansInitMode.scala index b3c023735..c1bca8e58 100644 --- a/ml/src/main/scala/frameless/ml/params/kmeans/KMeansInitMode.scala +++ b/ml/src/main/scala/frameless/ml/params/kmeans/KMeansInitMode.scala @@ -4,14 +4,14 @@ package params package kmeans /** - * Param for the initialization algorithm. 
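// Sketch: the mlVectorUdt and mlMatrixUdt instances above let TypedEncoder derivation handle
// case classes with Vector or Matrix columns (import frameless.ml._ when outside the ml
// package), so they can back a TypedDataset directly. Point and the local SparkSession are
// assumptions of this example.
import frameless.TypedDataset
import frameless.ml._
import org.apache.spark.ml.linalg.{ Vector, Vectors }
import org.apache.spark.sql.SparkSession

case class Point(features: Vector, label: Double)

object VectorEncodingExample {
  implicit val session: SparkSession =
    SparkSession.builder().master("local[*]").appName("udt-sketch").getOrCreate()

  val points: TypedDataset[Point] =
    TypedDataset.create(Seq(Point(Vectors.dense(1.0, 2.0), 0.0)))
}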
- * This can be either "random" to choose random points as - * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ - * (Bahmani et al., Scalable K-Means++, VLDB 2012). - * Default: k-means||. - */ + * Param for the initialization algorithm. + * This can be either "random" to choose random points as + * initial cluster centers, or "k-means||" to use a parallel variant of k-means++ + * (Bahmani et al., Scalable K-Means++, VLDB 2012). + * Default: k-means||. + */ -sealed abstract class KMeansInitMode private[ml](val sparkValue: String) +sealed abstract class KMeansInitMode private[ml] (val sparkValue: String) object KMeansInitMode { case object Random extends KMeansInitMode("random") diff --git a/ml/src/main/scala/frameless/ml/params/linears/LossStrategy.scala b/ml/src/main/scala/frameless/ml/params/linears/LossStrategy.scala index 4b9ca6d4e..0dbd896c4 100644 --- a/ml/src/main/scala/frameless/ml/params/linears/LossStrategy.scala +++ b/ml/src/main/scala/frameless/ml/params/linears/LossStrategy.scala @@ -2,15 +2,17 @@ package frameless package ml package params package linears + /** - * SquaredError measures the average of the squares of the errors—that is, - * the average squared difference between the estimated values and what is estimated. - * - * Huber Loss loss function less sensitive to outliers in data than the - * squared error loss - */ -sealed abstract class LossStrategy private[ml](val sparkValue: String) + * SquaredError measures the average of the squares of the errors—that is, + * the average squared difference between the estimated values and what is estimated. + * + * Huber Loss loss function less sensitive to outliers in data than the + * squared error loss + */ +sealed abstract class LossStrategy private[ml] (val sparkValue: String) + object LossStrategy { case object SquaredError extends LossStrategy("squaredError") - case object Huber extends LossStrategy("huber") + case object Huber extends LossStrategy("huber") } diff --git a/ml/src/main/scala/frameless/ml/params/linears/Solver.scala b/ml/src/main/scala/frameless/ml/params/linears/Solver.scala index 277e06e7a..7b14b8b95 100644 --- a/ml/src/main/scala/frameless/ml/params/linears/Solver.scala +++ b/ml/src/main/scala/frameless/ml/params/linears/Solver.scala @@ -4,22 +4,22 @@ package params package linears /** - * solver algorithm used for optimization. - * - "l-bfgs" denotes Limited-memory BFGS which is a limited-memory quasi-Newton - * optimization method. - * - "normal" denotes using Normal Equation as an analytical solution to the linear regression - * problem. This solver is limited to `LinearRegression.MAX_FEATURES_FOR_NORMAL_SOLVER`. - * - "auto" (default) means that the solver algorithm is selected automatically. - * The Normal Equations solver will be used when possible, but this will automatically fall - * back to iterative optimization methods when needed. - * - * spark - */ + * solver algorithm used for optimization. + * - "l-bfgs" denotes Limited-memory BFGS which is a limited-memory quasi-Newton + * optimization method. + * - "normal" denotes using Normal Equation as an analytical solution to the linear regression + * problem. This solver is limited to `LinearRegression.MAX_FEATURES_FOR_NORMAL_SOLVER`. + * - "auto" (default) means that the solver algorithm is selected automatically. + * The Normal Equations solver will be used when possible, but this will automatically fall + * back to iterative optimization methods when needed. 
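// Sketch: these param ADTs are thin wrappers over the string values Spark ML expects,
// exposed through sparkValue.
import frameless.ml.params.kmeans.KMeansInitMode
import frameless.ml.params.linears.LossStrategy

object ParamValuesExample {
  val init: String = KMeansInitMode.Random.sparkValue        // "random"
  val squared: String = LossStrategy.SquaredError.sparkValue // "squaredError"
  val huber: String = LossStrategy.Huber.sparkValue          // "huber"
}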
+ * + * spark + */ + +sealed abstract class Solver private[ml] (val sparkValue: String) -sealed abstract class Solver private[ml](val sparkValue: String) object Solver { - case object LBFGS extends Solver("l-bfgs") - case object Auto extends Solver("auto") - case object Normal extends Solver("normal") + case object LBFGS extends Solver("l-bfgs") + case object Auto extends Solver("auto") + case object Normal extends Solver("normal") } - diff --git a/ml/src/main/scala/frameless/ml/params/trees/FeatureSubsetStrategy.scala b/ml/src/main/scala/frameless/ml/params/trees/FeatureSubsetStrategy.scala index f2167f983..67f2ddc8d 100644 --- a/ml/src/main/scala/frameless/ml/params/trees/FeatureSubsetStrategy.scala +++ b/ml/src/main/scala/frameless/ml/params/trees/FeatureSubsetStrategy.scala @@ -2,32 +2,34 @@ package frameless package ml package params package trees + /** - * The number of features to consider for splits at each tree node. - * Supported options: - * - Auto: Choose automatically for task: - * If numTrees == 1, set to All - * If numTrees > 1 (forest), set to Sqrt for classification and - * to OneThird for regression. - * - All: use all features - * - OneThird: use 1/3 of the features - * - Sqrt: use sqrt(number of features) - * - Log2: use log2(number of features) - * - Ratio: use (ratio * number of features) features - * - NumberOfFeatures: use numberOfFeatures features. - * (default = Auto) - * - * These various settings are based on the following references: - * - log2: tested in Breiman (2001) - * - sqrt: recommended by Breiman manual for random forests - * - The defaults of sqrt (classification) and onethird (regression) match the R randomForest - * package. - * - * @see Breiman (2001) - * @see - * Breiman manual for random forests - */ -sealed abstract class FeatureSubsetStrategy private[ml](val sparkValue: String) + * The number of features to consider for splits at each tree node. + * Supported options: + * - Auto: Choose automatically for task: + * If numTrees == 1, set to All + * If numTrees > 1 (forest), set to Sqrt for classification and + * to OneThird for regression. + * - All: use all features + * - OneThird: use 1/3 of the features + * - Sqrt: use sqrt(number of features) + * - Log2: use log2(number of features) + * - Ratio: use (ratio * number of features) features + * - NumberOfFeatures: use numberOfFeatures features. + * (default = Auto) + * + * These various settings are based on the following references: + * - log2: tested in Breiman (2001) + * - sqrt: recommended by Breiman manual for random forests + * - The defaults of sqrt (classification) and onethird (regression) match the R randomForest + * package. 
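// Sketch: Solver and FeatureSubsetStrategy (defined just below) follow the same pattern;
// Ratio and NumberOfFeatures carry a value, the rest are plain case objects, and sparkValue
// is the string handed to Spark ML.
import frameless.ml.params.linears.Solver
import frameless.ml.params.trees.FeatureSubsetStrategy

object SolverAndStrategyExample {
  val solvers: Seq[String] =
    Seq(Solver.Auto, Solver.LBFGS, Solver.Normal).map(_.sparkValue)
  // Seq("auto", "l-bfgs", "normal")

  val strategies: Seq[String] = Seq(
    FeatureSubsetStrategy.Auto,
    FeatureSubsetStrategy.Sqrt,
    FeatureSubsetStrategy.Ratio(0.5),
    FeatureSubsetStrategy.NumberOfFeatures(3)
  ).map(_.sparkValue)
  // Seq("auto", "sqrt", "0.5", "3")
}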
+ * + * @see Breiman (2001) + * @see + * Breiman manual for random forests + */ +sealed abstract class FeatureSubsetStrategy private[ml] (val sparkValue: String) + object FeatureSubsetStrategy { case object Auto extends FeatureSubsetStrategy("auto") case object All extends FeatureSubsetStrategy("all") @@ -35,5 +37,7 @@ object FeatureSubsetStrategy { case object Sqrt extends FeatureSubsetStrategy("sqrt") case object Log2 extends FeatureSubsetStrategy("log2") case class Ratio(value: Double) extends FeatureSubsetStrategy(value.toString) - case class NumberOfFeatures(value: Int) extends FeatureSubsetStrategy(value.toString) -} \ No newline at end of file + + case class NumberOfFeatures(value: Int) + extends FeatureSubsetStrategy(value.toString) +} diff --git a/ml/src/main/scala/frameless/ml/regression/TypedLinearRegression.scala b/ml/src/main/scala/frameless/ml/regression/TypedLinearRegression.scala index 3b3208623..a1eeb169c 100644 --- a/ml/src/main/scala/frameless/ml/regression/TypedLinearRegression.scala +++ b/ml/src/main/scala/frameless/ml/regression/TypedLinearRegression.scala @@ -3,38 +3,66 @@ package ml package regression import frameless.ml.internals.LinearInputsChecker -import frameless.ml.params.linears.{LossStrategy, Solver} -import frameless.ml.{AppendTransformer, TypedEstimator} -import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} +import frameless.ml.params.linears.{ LossStrategy, Solver } +import frameless.ml.{ AppendTransformer, TypedEstimator } +import org.apache.spark.ml.regression.{ + LinearRegression, + LinearRegressionModel +} /** - * Linear Regression linear approach to modelling the relationship - * between a scalar response (or dependent variable) and one or more explanatory variables - */ -final class TypedLinearRegression [Inputs] private[ml]( - lr: LinearRegression, - labelCol: String, - featuresCol: String, - weightCol: Option[String] -) extends TypedEstimator[Inputs, TypedLinearRegression.Outputs, LinearRegressionModel] { - - val estimatorWithoutWeight : LinearRegression = lr + * Linear Regression linear approach to modelling the relationship + * between a scalar response (or dependent variable) and one or more explanatory variables + */ +final class TypedLinearRegression[Inputs] private[ml] ( + lr: LinearRegression, + labelCol: String, + featuresCol: String, + weightCol: Option[String]) + extends TypedEstimator[ + Inputs, + TypedLinearRegression.Outputs, + LinearRegressionModel + ] { + + val estimatorWithoutWeight: LinearRegression = lr .setLabelCol(labelCol) .setFeaturesCol(featuresCol) .setPredictionCol(AppendTransformer.tempColumnName) - val estimator = if (weightCol.isDefined) estimatorWithoutWeight.setWeightCol(weightCol.get) else estimatorWithoutWeight + val estimator = + if (weightCol.isDefined) estimatorWithoutWeight.setWeightCol(weightCol.get) + else estimatorWithoutWeight + + def setRegParam(value: Double): TypedLinearRegression[Inputs] = + copy(lr.setRegParam(value)) + + def setFitIntercept(value: Boolean): TypedLinearRegression[Inputs] = + copy(lr.setFitIntercept(value)) + + def setStandardization(value: Boolean): TypedLinearRegression[Inputs] = + copy(lr.setStandardization(value)) + + def setElasticNetParam(value: Double): TypedLinearRegression[Inputs] = + copy(lr.setElasticNetParam(value)) + + def setMaxIter(value: Int): TypedLinearRegression[Inputs] = + copy(lr.setMaxIter(value)) - def setRegParam(value: Double): TypedLinearRegression[Inputs] = copy(lr.setRegParam(value)) - def setFitIntercept(value: Boolean): 
TypedLinearRegression[Inputs] = copy(lr.setFitIntercept(value)) - def setStandardization(value: Boolean): TypedLinearRegression[Inputs] = copy(lr.setStandardization(value)) - def setElasticNetParam(value: Double): TypedLinearRegression[Inputs] = copy(lr.setElasticNetParam(value)) - def setMaxIter(value: Int): TypedLinearRegression[Inputs] = copy(lr.setMaxIter(value)) - def setTol(value: Double): TypedLinearRegression[Inputs] = copy(lr.setTol(value)) - def setSolver(value: Solver): TypedLinearRegression[Inputs] = copy(lr.setSolver(value.sparkValue)) - def setAggregationDepth(value: Int): TypedLinearRegression[Inputs] = copy(lr.setAggregationDepth(value)) - def setLoss(value: LossStrategy): TypedLinearRegression[Inputs] = copy(lr.setLoss(value.sparkValue)) - def setEpsilon(value: Double): TypedLinearRegression[Inputs] = copy(lr.setEpsilon(value)) + def setTol(value: Double): TypedLinearRegression[Inputs] = + copy(lr.setTol(value)) + + def setSolver(value: Solver): TypedLinearRegression[Inputs] = + copy(lr.setSolver(value.sparkValue)) + + def setAggregationDepth(value: Int): TypedLinearRegression[Inputs] = + copy(lr.setAggregationDepth(value)) + + def setLoss(value: LossStrategy): TypedLinearRegression[Inputs] = + copy(lr.setLoss(value.sparkValue)) + + def setEpsilon(value: Double): TypedLinearRegression[Inputs] = + copy(lr.setEpsilon(value)) private def copy(newLr: LinearRegression): TypedLinearRegression[Inputs] = new TypedLinearRegression[Inputs](newLr, labelCol, featuresCol, weightCol) @@ -45,8 +73,15 @@ object TypedLinearRegression { case class Outputs(prediction: Double) case class Weight(weight: Double) - - def apply[Inputs](implicit inputsChecker: LinearInputsChecker[Inputs]): TypedLinearRegression[Inputs] = { - new TypedLinearRegression(new LinearRegression(), inputsChecker.labelCol, inputsChecker.featuresCol, inputsChecker.weightCol) + def apply[Inputs]( + implicit + inputsChecker: LinearInputsChecker[Inputs] + ): TypedLinearRegression[Inputs] = { + new TypedLinearRegression( + new LinearRegression(), + inputsChecker.labelCol, + inputsChecker.featuresCol, + inputsChecker.weightCol + ) } -} \ No newline at end of file +} diff --git a/ml/src/main/scala/frameless/ml/regression/TypedRandomForestRegressor.scala b/ml/src/main/scala/frameless/ml/regression/TypedRandomForestRegressor.scala index 69c1ad68c..89586671b 100644 --- a/ml/src/main/scala/frameless/ml/regression/TypedRandomForestRegressor.scala +++ b/ml/src/main/scala/frameless/ml/regression/TypedRandomForestRegressor.scala @@ -4,44 +4,74 @@ package regression import frameless.ml.internals.TreesInputsChecker import frameless.ml.params.trees.FeatureSubsetStrategy -import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor} +import org.apache.spark.ml.regression.{ + RandomForestRegressionModel, + RandomForestRegressor +} /** - * Random Forest - * learning algorithm for regression. - * It supports both continuous and categorical features. - */ -final class TypedRandomForestRegressor[Inputs] private[ml]( - rf: RandomForestRegressor, - labelCol: String, - featuresCol: String -) extends TypedEstimator[Inputs, TypedRandomForestRegressor.Outputs, RandomForestRegressionModel] { + * Random Forest + * learning algorithm for regression. + * It supports both continuous and categorical features. 
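// Sketch: putting TypedLinearRegression together end to end. HouseData, the sample rows and
// the local SparkSession are assumptions of this example; LinearInputsChecker requires one
// Vector field (features) and one Double field (label), mirroring the test suites.
import frameless.TypedDataset
import frameless.ml._
import frameless.ml.params.linears.{ LossStrategy, Solver }
import frameless.ml.regression.TypedLinearRegression
import org.apache.spark.ml.linalg.{ Vector, Vectors }
import org.apache.spark.sql.SparkSession

case class HouseData(features: Vector, price: Double)

object LinearRegressionExample {
  implicit val session: SparkSession =
    SparkSession.builder().master("local[*]").appName("lr-sketch").getOrCreate()

  val lr = TypedLinearRegression[HouseData]
    .setMaxIter(50)
    .setRegParam(0.1)
    .setSolver(Solver.Auto)
    .setLoss(LossStrategy.SquaredError)

  val training: TypedDataset[HouseData] = TypedDataset.create(
    Seq(HouseData(Vectors.dense(120.0), 300000.0), HouseData(Vectors.dense(80.0), 200000.0))
  )

  // fit returns a lazy job; run() executes it and yields the fitted transformer.
  val model = lr.fit(training).run()
}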
+ */ +final class TypedRandomForestRegressor[Inputs] private[ml] ( + rf: RandomForestRegressor, + labelCol: String, + featuresCol: String) + extends TypedEstimator[ + Inputs, + TypedRandomForestRegressor.Outputs, + RandomForestRegressionModel + ] { val estimator: RandomForestRegressor = - rf - .setLabelCol(labelCol) + rf.setLabelCol(labelCol) .setFeaturesCol(featuresCol) .setPredictionCol(AppendTransformer.tempColumnName) - def setNumTrees(value: Int): TypedRandomForestRegressor[Inputs] = copy(rf.setNumTrees(value)) - def setMaxDepth(value: Int): TypedRandomForestRegressor[Inputs] = copy(rf.setMaxDepth(value)) - def setMinInfoGain(value: Double): TypedRandomForestRegressor[Inputs] = copy(rf.setMinInfoGain(value)) - def setMinInstancesPerNode(value: Int): TypedRandomForestRegressor[Inputs] = copy(rf.setMinInstancesPerNode(value)) - def setMaxMemoryInMB(value: Int): TypedRandomForestRegressor[Inputs] = copy(rf.setMaxMemoryInMB(value)) - def setSubsamplingRate(value: Double): TypedRandomForestRegressor[Inputs] = copy(rf.setSubsamplingRate(value)) - def setFeatureSubsetStrategy(value: FeatureSubsetStrategy): TypedRandomForestRegressor[Inputs] = + def setNumTrees(value: Int): TypedRandomForestRegressor[Inputs] = + copy(rf.setNumTrees(value)) + + def setMaxDepth(value: Int): TypedRandomForestRegressor[Inputs] = + copy(rf.setMaxDepth(value)) + + def setMinInfoGain(value: Double): TypedRandomForestRegressor[Inputs] = + copy(rf.setMinInfoGain(value)) + + def setMinInstancesPerNode(value: Int): TypedRandomForestRegressor[Inputs] = + copy(rf.setMinInstancesPerNode(value)) + + def setMaxMemoryInMB(value: Int): TypedRandomForestRegressor[Inputs] = + copy(rf.setMaxMemoryInMB(value)) + + def setSubsamplingRate(value: Double): TypedRandomForestRegressor[Inputs] = + copy(rf.setSubsamplingRate(value)) + + def setFeatureSubsetStrategy( + value: FeatureSubsetStrategy + ): TypedRandomForestRegressor[Inputs] = copy(rf.setFeatureSubsetStrategy(value.sparkValue)) - def setMaxBins(value: Int): TypedRandomForestRegressor[Inputs] = copy(rf.setMaxBins(value)) - private def copy(newRf: RandomForestRegressor): TypedRandomForestRegressor[Inputs] = + def setMaxBins(value: Int): TypedRandomForestRegressor[Inputs] = + copy(rf.setMaxBins(value)) + + private def copy( + newRf: RandomForestRegressor + ): TypedRandomForestRegressor[Inputs] = new TypedRandomForestRegressor[Inputs](newRf, labelCol, featuresCol) } object TypedRandomForestRegressor { case class Outputs(prediction: Double) - def apply[Inputs](implicit inputsChecker: TreesInputsChecker[Inputs]) - : TypedRandomForestRegressor[Inputs] = { - new TypedRandomForestRegressor(new RandomForestRegressor(), inputsChecker.labelCol, inputsChecker.featuresCol) + def apply[Inputs]( + implicit + inputsChecker: TreesInputsChecker[Inputs] + ): TypedRandomForestRegressor[Inputs] = { + new TypedRandomForestRegressor( + new RandomForestRegressor(), + inputsChecker.labelCol, + inputsChecker.featuresCol + ) } -} \ No newline at end of file +} diff --git a/ml/src/main/scala/org/apache/spark/ml/FramelessInternals.scala b/ml/src/main/scala/org/apache/spark/ml/FramelessInternals.scala index bec43cd11..6f697ea28 100644 --- a/ml/src/main/scala/org/apache/spark/ml/FramelessInternals.scala +++ b/ml/src/main/scala/org/apache/spark/ml/FramelessInternals.scala @@ -1,6 +1,6 @@ package org.apache.spark.ml -import org.apache.spark.ml.linalg.{MatrixUDT, VectorUDT} +import org.apache.spark.ml.linalg.{ MatrixUDT, VectorUDT } object FramelessInternals { diff --git 
a/ml/src/test/scala/frameless/ml/FramelessMlSuite.scala b/ml/src/test/scala/frameless/ml/FramelessMlSuite.scala index de8fcab56..60104fbd2 100644 --- a/ml/src/test/scala/frameless/ml/FramelessMlSuite.scala +++ b/ml/src/test/scala/frameless/ml/FramelessMlSuite.scala @@ -6,7 +6,12 @@ import org.scalatest.BeforeAndAfterAll import org.scalatestplus.scalacheck.Checkers import org.scalatest.funsuite.AnyFunSuite -class FramelessMlSuite extends AnyFunSuite with Checkers with BeforeAndAfterAll with SparkTesting { +class FramelessMlSuite + extends AnyFunSuite + with Checkers + with BeforeAndAfterAll + with SparkTesting { + // Limit size of generated collections and number of checks because Travis implicit override val generatorDrivenConfig = PropertyCheckConfiguration(sizeRange = PosZInt(10), minSize = PosZInt(10)) diff --git a/ml/src/test/scala/frameless/ml/Generators.scala b/ml/src/test/scala/frameless/ml/Generators.scala index f7dde986c..48c0e4bc0 100644 --- a/ml/src/test/scala/frameless/ml/Generators.scala +++ b/ml/src/test/scala/frameless/ml/Generators.scala @@ -1,15 +1,18 @@ package frameless package ml -import frameless.ml.params.linears.{LossStrategy, Solver} +import frameless.ml.params.linears.{ LossStrategy, Solver } import frameless.ml.params.trees.FeatureSubsetStrategy -import org.apache.spark.ml.linalg.{Matrices, Matrix, Vector, Vectors} -import org.scalacheck.{Arbitrary, Gen} +import org.apache.spark.ml.linalg.{ Matrices, Matrix, Vector, Vectors } +import org.scalacheck.{ Arbitrary, Gen } object Generators { implicit val arbVector: Arbitrary[Vector] = Arbitrary { - val genDenseVector = Gen.listOf(arbDouble.arbitrary).suchThat(_.nonEmpty).map(doubles => Vectors.dense(doubles.toArray)) + val genDenseVector = Gen + .listOf(arbDouble.arbitrary) + .suchThat(_.nonEmpty) + .map(doubles => Vectors.dense(doubles.toArray)) val genSparseVector = genDenseVector.map(_.toSparse) Gen.oneOf(genDenseVector, genSparseVector) @@ -21,29 +24,34 @@ object Generators { nbRows <- Gen.choose(0, size) nbCols <- Gen.choose(1, size) matrix <- { - Gen.listOfN(nbRows * nbCols, arbDouble.arbitrary) + Gen + .listOfN(nbRows * nbCols, arbDouble.arbitrary) .map(values => Matrices.dense(nbRows, nbCols, values.toArray)) } } yield matrix } } - implicit val arbTreesFeaturesSubsetStrategy: Arbitrary[FeatureSubsetStrategy] = Arbitrary { - val genRatio = Gen.choose(0D, 1D).suchThat(_ > 0D).map(FeatureSubsetStrategy.Ratio) - val genNumberOfFeatures = Gen.choose(1, Int.MaxValue).map(FeatureSubsetStrategy.NumberOfFeatures) - - Gen.oneOf(Gen.const(FeatureSubsetStrategy.All), - Gen.const(FeatureSubsetStrategy.All), - Gen.const(FeatureSubsetStrategy.Log2), - Gen.const(FeatureSubsetStrategy.OneThird), - Gen.const(FeatureSubsetStrategy.Sqrt), - genRatio, - genNumberOfFeatures - ) - } + implicit val arbTreesFeaturesSubsetStrategy: Arbitrary[FeatureSubsetStrategy] = + Arbitrary { + val genRatio = + Gen.choose(0D, 1D).suchThat(_ > 0D).map(FeatureSubsetStrategy.Ratio) + val genNumberOfFeatures = + Gen.choose(1, Int.MaxValue).map(FeatureSubsetStrategy.NumberOfFeatures) + + Gen.oneOf( + Gen.const(FeatureSubsetStrategy.All), + Gen.const(FeatureSubsetStrategy.All), + Gen.const(FeatureSubsetStrategy.Log2), + Gen.const(FeatureSubsetStrategy.OneThird), + Gen.const(FeatureSubsetStrategy.Sqrt), + genRatio, + genNumberOfFeatures + ) + } implicit val arbLossStrategy: Arbitrary[LossStrategy] = Arbitrary { - Gen.const(LossStrategy.SquaredError) + Gen.const(LossStrategy.SquaredError) } implicit val arbSolver: Arbitrary[Solver] = Arbitrary { diff --git 
a/ml/src/test/scala/frameless/ml/TypedEncoderInstancesTests.scala b/ml/src/test/scala/frameless/ml/TypedEncoderInstancesTests.scala index 0f7e37439..a54e09e42 100644 --- a/ml/src/test/scala/frameless/ml/TypedEncoderInstancesTests.scala +++ b/ml/src/test/scala/frameless/ml/TypedEncoderInstancesTests.scala @@ -23,12 +23,15 @@ class TypedEncoderInstancesTests extends FramelessMlSuite { check(prop) } - test("Vector is encoded as VectorUDT and thus can be run in a Spark ML model") { + test( + "Vector is encoded as VectorUDT and thus can be run in a Spark ML model" + ) { case class Input(features: Vector, label: Double) val prop = forAll { trainingData: Matrix => (trainingData.numRows >= 1) ==> { - val inputs = trainingData.rowIter.toVector.map(vector => Input(vector, 0D)) + val inputs = + trainingData.rowIter.toVector.map(vector => Input(vector, 0D)) val inputsDS = TypedDataset.create(inputs) val model = new DecisionTreeRegressor() @@ -39,7 +42,8 @@ class TypedEncoderInstancesTests extends FramelessMlSuite { val randomInput = inputs(Random.nextInt(inputs.length)) val randomInputDS = TypedDataset.create(Seq(randomInput)) - val prediction = trainedModel.transform(randomInputDS.dataset) + val prediction = trainedModel + .transform(randomInputDS.dataset) .select("prediction") .head() .getAs[Double](0) diff --git a/ml/src/test/scala/frameless/ml/classification/ClassificationIntegrationTests.scala b/ml/src/test/scala/frameless/ml/classification/ClassificationIntegrationTests.scala index d98c3b2bf..9fbe2bebc 100644 --- a/ml/src/test/scala/frameless/ml/classification/ClassificationIntegrationTests.scala +++ b/ml/src/test/scala/frameless/ml/classification/ClassificationIntegrationTests.scala @@ -2,7 +2,11 @@ package frameless package ml package classification -import frameless.ml.feature.{TypedIndexToString, TypedStringIndexer, TypedVectorAssembler} +import frameless.ml.feature.{ + TypedIndexToString, + TypedStringIndexer, + TypedVectorAssembler +} import org.apache.spark.ml.linalg.Vector import org.scalatest.matchers.must.Matchers @@ -18,15 +22,26 @@ class ClassificationIntegrationTests extends FramelessMlSuite with Matchers { case class Features(field1: Double, field2: Int) val vectorAssembler = TypedVectorAssembler[Features] - case class DataWithFeatures(field1: Double, field2: Int, field3: String, features: Vector) - val dataWithFeatures = vectorAssembler.transform(trainingDataDs).as[DataWithFeatures]() + case class DataWithFeatures( + field1: Double, + field2: Int, + field3: String, + features: Vector) + val dataWithFeatures = + vectorAssembler.transform(trainingDataDs).as[DataWithFeatures]() case class StringIndexerInput(field3: String) val indexer = TypedStringIndexer[StringIndexerInput] val indexerModel = indexer.fit(dataWithFeatures).run() - case class IndexedDataWithFeatures(field1: Double, field2: Int, field3: String, features: Vector, indexedField3: Double) - val indexedData = indexerModel.transform(dataWithFeatures).as[IndexedDataWithFeatures]() + case class IndexedDataWithFeatures( + field1: Double, + field2: Int, + field3: String, + features: Vector, + indexedField3: Double) + val indexedData = + indexerModel.transform(dataWithFeatures).as[IndexedDataWithFeatures]() case class RFInputs(indexedField3: Double, features: Vector) val rf = TypedRandomForestClassifier[RFInputs] @@ -35,38 +50,47 @@ class ClassificationIntegrationTests extends FramelessMlSuite with Matchers { // Prediction - val testData = TypedDataset.create(Seq( - Data(0D, 10, "foo") - )) - val testDataWithFeatures = 
vectorAssembler.transform(testData).as[DataWithFeatures]() - val indexedTestData = indexerModel.transform(testDataWithFeatures).as[IndexedDataWithFeatures]() + val testData = TypedDataset.create( + Seq( + Data(0D, 10, "foo") + ) + ) + val testDataWithFeatures = + vectorAssembler.transform(testData).as[DataWithFeatures]() + val indexedTestData = + indexerModel.transform(testDataWithFeatures).as[IndexedDataWithFeatures]() case class PredictionInputs(features: Vector, indexedField3: Double) val testInput = indexedTestData.project[PredictionInputs] case class PredictionResultIndexed( - features: Vector, - indexedField3: Double, - rawPrediction: Vector, - probability: Vector, - predictedField3Indexed: Double - ) + features: Vector, + indexedField3: Double, + rawPrediction: Vector, + probability: Vector, + predictedField3Indexed: Double) val predictionDs = model.transform(testInput).as[PredictionResultIndexed]() case class IndexToStringInput(predictedField3Indexed: Double) - val indexToString = TypedIndexToString[IndexToStringInput](indexerModel.transformer.labelsArray.flatten) - - case class PredictionResult( - features: Vector, - indexedField3: Double, - rawPrediction: Vector, - probability: Vector, - predictedField3Indexed: Double, - predictedField3: String + val indexToString = TypedIndexToString[IndexToStringInput]( + indexerModel.transformer.labelsArray.flatten ) - val stringPredictionDs = indexToString.transform(predictionDs).as[PredictionResult]() - val prediction = stringPredictionDs.select(stringPredictionDs.col('predictedField3)).collect().run().toList + case class PredictionResult( + features: Vector, + indexedField3: Double, + rawPrediction: Vector, + probability: Vector, + predictedField3Indexed: Double, + predictedField3: String) + val stringPredictionDs = + indexToString.transform(predictionDs).as[PredictionResult]() + + val prediction = stringPredictionDs + .select(stringPredictionDs.col('predictedField3)) + .collect() + .run() + .toList prediction mustEqual List("foo") } diff --git a/ml/src/test/scala/frameless/ml/classification/TypedRandomForestClassifierTests.scala b/ml/src/test/scala/frameless/ml/classification/TypedRandomForestClassifierTests.scala index ab03f1aad..f5571d17a 100644 --- a/ml/src/test/scala/frameless/ml/classification/TypedRandomForestClassifierTests.scala +++ b/ml/src/test/scala/frameless/ml/classification/TypedRandomForestClassifierTests.scala @@ -5,15 +5,21 @@ package classification import shapeless.test.illTyped import org.apache.spark.ml.linalg._ import frameless.ml.params.trees.FeatureSubsetStrategy -import org.scalacheck.{Arbitrary, Gen} +import org.scalacheck.{ Arbitrary, Gen } import org.scalacheck.Prop._ import org.scalatest.matchers.must.Matchers class TypedRandomForestClassifierTests extends FramelessMlSuite with Matchers { + implicit val arbDouble: Arbitrary[Double] = - Arbitrary(Gen.choose(1, 99).map(_.toDouble)) // num classes must be between 0 and 100 for the test + Arbitrary( + Gen.choose(1, 99).map(_.toDouble) + ) // num classes must be between 0 and 100 for the test + implicit val arbVectorNonEmpty: Arbitrary[Vector] = - Arbitrary(Generators.arbVector.arbitrary suchThat (_.size > 0)) // vector must not be empty for RandomForestClassifier + Arbitrary( + Generators.arbVector.arbitrary suchThat (_.size > 0) + ) // vector must not be empty for RandomForestClassifier import Generators.arbTreesFeaturesSubsetStrategy test("fit() returns a correct TypedTransformer") { @@ -21,7 +27,8 @@ class TypedRandomForestClassifierTests extends FramelessMlSuite 
with Matchers { val rf = TypedRandomForestClassifier[X2[Double, Vector]] val ds = TypedDataset.create(Seq(x2)) val model = rf.fit(ds).run() - val pDs = model.transform(ds).as[X5[Double, Vector, Vector, Vector, Double]]() + val pDs = + model.transform(ds).as[X5[Double, Vector, Vector, Vector, Double]]() pDs.select(pDs.col('a), pDs.col('b)).collect().run() == Seq(x2.a -> x2.b) } @@ -30,19 +37,26 @@ class TypedRandomForestClassifierTests extends FramelessMlSuite with Matchers { val rf = TypedRandomForestClassifier[X2[Vector, Double]] val ds = TypedDataset.create(Seq(x2)) val model = rf.fit(ds).run() - val pDs = model.transform(ds).as[X5[Vector, Double, Vector, Vector, Double]]() + val pDs = + model.transform(ds).as[X5[Vector, Double, Vector, Vector, Double]]() pDs.select(pDs.col('a), pDs.col('b)).collect().run() == Seq(x2.a -> x2.b) } - def prop3[A: TypedEncoder: Arbitrary] = forAll { x3: X3[Vector, Double, A] => - val rf = TypedRandomForestClassifier[X2[Vector, Double]] - val ds = TypedDataset.create(Seq(x3)) - val model = rf.fit(ds).run() - val pDs = model.transform(ds).as[X6[Vector, Double, A, Vector, Vector, Double]]() + def prop3[A: TypedEncoder: Arbitrary] = + forAll { x3: X3[Vector, Double, A] => + val rf = TypedRandomForestClassifier[X2[Vector, Double]] + val ds = TypedDataset.create(Seq(x3)) + val model = rf.fit(ds).run() + val pDs = model + .transform(ds) + .as[X6[Vector, Double, A, Vector, Vector, Double]]() - pDs.select(pDs.col('a), pDs.col('b), pDs.col('c)).collect().run() == Seq((x3.a, x3.b, x3.c)) - } + pDs + .select(pDs.col('a), pDs.col('b), pDs.col('c)) + .collect() + .run() == Seq((x3.a, x3.b, x3.c)) + } check(prop) check(prop2) @@ -66,13 +80,13 @@ class TypedRandomForestClassifierTests extends FramelessMlSuite with Matchers { val model = rf.fit(ds).run() model.transformer.getNumTrees == 10 && - model.transformer.getMaxBins == 100 && - model.transformer.getFeatureSubsetStrategy == featureSubsetStrategy.sparkValue && - model.transformer.getMaxDepth == 10 && - model.transformer.getMaxMemoryInMB == 100 && - model.transformer.getMinInfoGain == 0.1D && - model.transformer.getMinInstancesPerNode == 2 && - model.transformer.getSubsamplingRate == 0.9D + model.transformer.getMaxBins == 100 && + model.transformer.getFeatureSubsetStrategy == featureSubsetStrategy.sparkValue && + model.transformer.getMaxDepth == 10 && + model.transformer.getMaxMemoryInMB == 100 && + model.transformer.getMinInfoGain == 0.1D && + model.transformer.getMinInstancesPerNode == 2 && + model.transformer.getSubsamplingRate == 0.9D } check(prop) @@ -86,4 +100,4 @@ class TypedRandomForestClassifierTests extends FramelessMlSuite with Matchers { illTyped("TypedRandomForestClassifier.create[X2[Vector, String]]()") } -} \ No newline at end of file +} diff --git a/ml/src/test/scala/frameless/ml/clustering/BisectingKMeansTests.scala b/ml/src/test/scala/frameless/ml/clustering/BisectingKMeansTests.scala index 976df39b2..c19c42ce3 100644 --- a/ml/src/test/scala/frameless/ml/clustering/BisectingKMeansTests.scala +++ b/ml/src/test/scala/frameless/ml/clustering/BisectingKMeansTests.scala @@ -2,7 +2,7 @@ package frameless package ml package clustering -import frameless.{TypedDataset, TypedEncoder, X1, X2, X3} +import frameless.{ TypedDataset, TypedEncoder, X1, X2, X3 } import frameless.ml.classification.TypedBisectingKMeans import org.scalacheck.Arbitrary import org.apache.spark.ml.linalg._ @@ -11,6 +11,7 @@ import frameless.ml._ import org.scalatest.matchers.must.Matchers class BisectingKMeansTests extends FramelessMlSuite 
with Matchers { + implicit val arbVector: Arbitrary[Vector] = Arbitrary(Generators.arbVector.arbitrary) @@ -24,7 +25,7 @@ class BisectingKMeansTests extends FramelessMlSuite with Matchers { pDs.select(pDs.col('a)).collect().run().toList == Seq(x1.a) } - def prop3[A: TypedEncoder : Arbitrary] = forAll { x2: X2[Vector, A] => + def prop3[A: TypedEncoder: Arbitrary] = forAll { x2: X2[Vector, A] => val km = TypedBisectingKMeans[X1[Vector]]() val ds = TypedDataset.create(Seq(x2)) val model = km.fit(ds).run() @@ -44,12 +45,12 @@ class BisectingKMeansTests extends FramelessMlSuite with Matchers { .setMinDivisibleClusterSize(1) .setSeed(123332) - val ds = TypedDataset.create(Seq(X2(Vectors.dense(Array(0D)),0))) + val ds = TypedDataset.create(Seq(X2(Vectors.dense(Array(0D)), 0))) val model = rf.fit(ds).run() - model.transformer.getK == 10 && - model.transformer.getMaxIter == 10 && - model.transformer.getMinDivisibleClusterSize == 1 && - model.transformer.getSeed == 123332 + model.transformer.getK == 10 && + model.transformer.getMaxIter == 10 && + model.transformer.getMinDivisibleClusterSize == 1 && + model.transformer.getSeed == 123332 } } diff --git a/ml/src/test/scala/frameless/ml/clustering/ClusteringIntegrationTests.scala b/ml/src/test/scala/frameless/ml/clustering/ClusteringIntegrationTests.scala index 398a0963d..a59cc03e3 100644 --- a/ml/src/test/scala/frameless/ml/clustering/ClusteringIntegrationTests.scala +++ b/ml/src/test/scala/frameless/ml/clustering/ClusteringIntegrationTests.scala @@ -3,7 +3,7 @@ package ml package clustering import frameless.ml.FramelessMlSuite -import frameless.ml.classification.{TypedBisectingKMeans, TypedKMeans} +import frameless.ml.classification.{ TypedBisectingKMeans, TypedKMeans } import org.apache.spark.ml.linalg.Vector import frameless._ import frameless.ml._ @@ -14,11 +14,13 @@ class ClusteringIntegrationTests extends FramelessMlSuite with Matchers { test("predict field2 from field1 using a K-means clustering") { // Training - val trainingDataDs = TypedDataset.create(Seq.fill(5)(X2(10D, 0)) :+ X2(100D,0)) + val trainingDataDs = + TypedDataset.create(Seq.fill(5)(X2(10D, 0)) :+ X2(100D, 0)) val vectorAssembler = TypedVectorAssembler[X1[Double]] - val dataWithFeatures = vectorAssembler.transform(trainingDataDs).as[X3[Double,Int,Vector]]() + val dataWithFeatures = + vectorAssembler.transform(trainingDataDs).as[X3[Double, Int, Vector]]() case class Input(c: Vector) val km = TypedKMeans[Input].setK(2) @@ -32,22 +34,27 @@ class ClusteringIntegrationTests extends FramelessMlSuite with Matchers { ) val testData = TypedDataset.create(testSeq) - val testDataWithFeatures = vectorAssembler.transform(testData).as[X3[Double,Int,Vector]]() + val testDataWithFeatures = + vectorAssembler.transform(testData).as[X3[Double, Int, Vector]]() - val predictionDs = model.transform(testDataWithFeatures).as[X4[Double,Int,Vector,Int]]() + val predictionDs = + model.transform(testDataWithFeatures).as[X4[Double, Int, Vector, Int]]() - val prediction = predictionDs.select(predictionDs.col[Int]('d)).collect().run().toList + val prediction = + predictionDs.select(predictionDs.col[Int]('d)).collect().run().toList prediction mustEqual testSeq.map(_.b) } test("predict field2 from field1 using a bisecting K-means clustering") { // Training - val trainingDataDs = TypedDataset.create(Seq.fill(5)(X2(10D, 0)) :+ X2(100D,0)) + val trainingDataDs = + TypedDataset.create(Seq.fill(5)(X2(10D, 0)) :+ X2(100D, 0)) val vectorAssembler = TypedVectorAssembler[X1[Double]] - val dataWithFeatures = 
vectorAssembler.transform(trainingDataDs).as[X3[Double, Int, Vector]]() + val dataWithFeatures = + vectorAssembler.transform(trainingDataDs).as[X3[Double, Int, Vector]]() case class Inputs(c: Vector) val bkm = TypedBisectingKMeans[Inputs]().setK(2) @@ -61,11 +68,14 @@ class ClusteringIntegrationTests extends FramelessMlSuite with Matchers { ) val testData = TypedDataset.create(testSeq) - val testDataWithFeatures = vectorAssembler.transform(testData).as[X3[Double, Int, Vector]]() + val testDataWithFeatures = + vectorAssembler.transform(testData).as[X3[Double, Int, Vector]]() - val predictionDs = model.transform(testDataWithFeatures).as[X4[Double,Int,Vector,Int]]() + val predictionDs = + model.transform(testDataWithFeatures).as[X4[Double, Int, Vector, Int]]() - val prediction = predictionDs.select(predictionDs.col[Int]('d)).collect().run().toList + val prediction = + predictionDs.select(predictionDs.col[Int]('d)).collect().run().toList prediction mustEqual testSeq.map(_.b) } diff --git a/ml/src/test/scala/frameless/ml/clustering/KMeansTests.scala b/ml/src/test/scala/frameless/ml/clustering/KMeansTests.scala index a41c1b703..1cb1441b9 100644 --- a/ml/src/test/scala/frameless/ml/clustering/KMeansTests.scala +++ b/ml/src/test/scala/frameless/ml/clustering/KMeansTests.scala @@ -3,17 +3,19 @@ package ml package clustering import frameless.ml.classification.TypedKMeans -import frameless.{TypedDataset, TypedEncoder, X1, X2, X3} +import frameless.{ TypedDataset, TypedEncoder, X1, X2, X3 } import org.apache.spark.ml.linalg._ -import org.scalacheck.{Arbitrary, Gen} +import org.scalacheck.{ Arbitrary, Gen } import org.scalacheck.Prop._ import frameless.ml._ import frameless.ml.params.kmeans.KMeansInitMode import org.scalatest.matchers.must.Matchers class KMeansTests extends FramelessMlSuite with Matchers { + implicit val arbVector: Arbitrary[Vector] = Arbitrary(Generators.arbVector.arbitrary) + implicit val arbKMeansInitMode: Arbitrary[KMeansInitMode] = Arbitrary { Gen.oneOf( @@ -30,7 +32,7 @@ class KMeansTests extends FramelessMlSuite with Matchers { val dense = Vectors.dense(dubs) vect match { case _: SparseVector => dense.toSparse - case _ => dense + case _ => dense } } @@ -46,17 +48,20 @@ class KMeansTests extends FramelessMlSuite with Matchers { pDs.select(pDs.col('a)).collect().run().toList == Seq(x1.a, x1a.a) } - def prop3[A: TypedEncoder : Arbitrary] = forAll { x2: X2[Vector, A] => + def prop3[A: TypedEncoder: Arbitrary] = forAll { x2: X2[Vector, A] => val x2a = x2.copy(a = newRowWithSameDimension(x2.a)) val km = TypedKMeans[X1[Vector]] val ds = TypedDataset.create(Seq(x2, x2a)) val model = km.fit(ds).run() val pDs = model.transform(ds).as[X3[Vector, A, Int]]() - pDs.select(pDs.col('a), pDs.col('b)).collect().run().toList == Seq((x2.a, x2.b), (x2a.a, x2a.b)) + pDs.select(pDs.col('a), pDs.col('b)).collect().run().toList == Seq( + (x2.a, x2.b), + (x2a.a, x2a.b) + ) } - tolerantRun( _.isInstanceOf[ArrayIndexOutOfBoundsException] ) { + tolerantRun(_.isInstanceOf[ArrayIndexOutOfBoundsException]) { check(prop) check(prop3[Double]) } @@ -76,11 +81,11 @@ class KMeansTests extends FramelessMlSuite with Matchers { val model = rf.fit(ds).run() model.transformer.getInitMode == KMeansInitMode.Random.sparkValue && - model.transformer.getInitSteps == 2 && - model.transformer.getK == 10 && - model.transformer.getMaxIter == 15 && - model.transformer.getSeed == 123223L && - model.transformer.getTol == 12D + model.transformer.getInitSteps == 2 && + model.transformer.getK == 10 && + model.transformer.getMaxIter == 
15 && + model.transformer.getSeed == 123223L && + model.transformer.getTol == 12D } check(prop) diff --git a/ml/src/test/scala/frameless/ml/feature/TypedIndexToStringTests.scala b/ml/src/test/scala/frameless/ml/feature/TypedIndexToStringTests.scala index a27f2966d..8dba04d52 100644 --- a/ml/src/test/scala/frameless/ml/feature/TypedIndexToStringTests.scala +++ b/ml/src/test/scala/frameless/ml/feature/TypedIndexToStringTests.scala @@ -2,7 +2,7 @@ package frameless package ml package feature -import org.scalacheck.{Arbitrary, Gen} +import org.scalacheck.{ Arbitrary, Gen } import org.scalacheck.Prop._ import shapeless.test.illTyped import org.scalatest.matchers.must.Matchers diff --git a/ml/src/test/scala/frameless/ml/feature/TypedStringIndexerTests.scala b/ml/src/test/scala/frameless/ml/feature/TypedStringIndexerTests.scala index 18d490758..db27983a6 100644 --- a/ml/src/test/scala/frameless/ml/feature/TypedStringIndexerTests.scala +++ b/ml/src/test/scala/frameless/ml/feature/TypedStringIndexerTests.scala @@ -3,7 +3,7 @@ package ml package feature import frameless.ml.feature.TypedStringIndexer.HandleInvalid -import org.scalacheck.{Arbitrary, Gen} +import org.scalacheck.{ Arbitrary, Gen } import org.scalacheck.Prop._ import shapeless.test.illTyped import org.scalatest.matchers.must.Matchers @@ -11,7 +11,7 @@ import org.scalatest.matchers.must.Matchers class TypedStringIndexerTests extends FramelessMlSuite with Matchers { test(".fit() returns a correct TypedTransformer") { - def prop[A: TypedEncoder : Arbitrary] = forAll { x2: X2[String, A] => + def prop[A: TypedEncoder: Arbitrary] = forAll { x2: X2[String, A] => val indexer = TypedStringIndexer[X1[String]] val ds = TypedDataset.create(Seq(x2)) val model = indexer.fit(ds).run() @@ -30,8 +30,8 @@ class TypedStringIndexerTests extends FramelessMlSuite with Matchers { } val prop = forAll { handleInvalid: HandleInvalid => - val indexer = TypedStringIndexer[X1[String]] - .setHandleInvalid(handleInvalid) + val indexer = + TypedStringIndexer[X1[String]].setHandleInvalid(handleInvalid) val ds = TypedDataset.create(Seq(X1("foo"))) val model = indexer.fit(ds).run() diff --git a/ml/src/test/scala/frameless/ml/feature/TypedVectorAssemblerTests.scala b/ml/src/test/scala/frameless/ml/feature/TypedVectorAssemblerTests.scala index 52cbc022b..918f4b487 100644 --- a/ml/src/test/scala/frameless/ml/feature/TypedVectorAssemblerTests.scala +++ b/ml/src/test/scala/frameless/ml/feature/TypedVectorAssemblerTests.scala @@ -10,23 +10,54 @@ import shapeless.test.illTyped class TypedVectorAssemblerTests extends FramelessMlSuite { test(".transform() returns a correct TypedTransformer") { - def prop[A: TypedEncoder: Arbitrary] = forAll { x5: X5[Int, Long, Double, Boolean, A] => - val assembler = TypedVectorAssembler[X4[Int, Long, Double, Boolean]] - val ds = TypedDataset.create(Seq(x5)) - val ds2 = assembler.transform(ds).as[X6[Int, Long, Double, Boolean, A, Vector]]() - - ds2.collect().run() == - Seq(X6(x5.a, x5.b, x5.c, x5.d, x5.e, Vectors.dense(x5.a.toDouble, x5.b.toDouble, x5.c, if (x5.d) 1D else 0D))) - } - - def prop2[A: TypedEncoder: Arbitrary] = forAll { x5: X5[Boolean, BigDecimal, Byte, Short, A] => - val assembler = TypedVectorAssembler[X4[Boolean, BigDecimal, Byte, Short]] - val ds = TypedDataset.create(Seq(x5)) - val ds2 = assembler.transform(ds).as[X6[Boolean, BigDecimal, Byte, Short, A, Vector]]() - - ds2.collect().run() == - Seq(X6(x5.a, x5.b, x5.c, x5.d, x5.e, Vectors.dense(if (x5.a) 1D else 0D, x5.b.toDouble, x5.c.toDouble, x5.d.toDouble))) - } + def prop[A: 
TypedEncoder: Arbitrary] = + forAll { x5: X5[Int, Long, Double, Boolean, A] => + val assembler = TypedVectorAssembler[X4[Int, Long, Double, Boolean]] + val ds = TypedDataset.create(Seq(x5)) + val ds2 = assembler + .transform(ds) + .as[X6[Int, Long, Double, Boolean, A, Vector]]() + + ds2.collect().run() == + Seq( + X6( + x5.a, + x5.b, + x5.c, + x5.d, + x5.e, + Vectors + .dense(x5.a.toDouble, x5.b.toDouble, x5.c, if (x5.d) 1D else 0D) + ) + ) + } + + def prop2[A: TypedEncoder: Arbitrary] = + forAll { x5: X5[Boolean, BigDecimal, Byte, Short, A] => + val assembler = + TypedVectorAssembler[X4[Boolean, BigDecimal, Byte, Short]] + val ds = TypedDataset.create(Seq(x5)) + val ds2 = assembler + .transform(ds) + .as[X6[Boolean, BigDecimal, Byte, Short, A, Vector]]() + + ds2.collect().run() == + Seq( + X6( + x5.a, + x5.b, + x5.c, + x5.d, + x5.e, + Vectors.dense( + if (x5.a) 1D else 0D, + x5.b.toDouble, + x5.c.toDouble, + x5.d.toDouble + ) + ) + ) + } check(prop[String]) check(prop[Double]) diff --git a/ml/src/test/scala/frameless/ml/regression/RegressionIntegrationTests.scala b/ml/src/test/scala/frameless/ml/regression/RegressionIntegrationTests.scala index b3db83c74..07ebf7745 100644 --- a/ml/src/test/scala/frameless/ml/regression/RegressionIntegrationTests.scala +++ b/ml/src/test/scala/frameless/ml/regression/RegressionIntegrationTests.scala @@ -18,8 +18,13 @@ class RegressionIntegrationTests extends FramelessMlSuite with Matchers { case class Features(field1: Double, field2: Int) val vectorAssembler = TypedVectorAssembler[Features] - case class DataWithFeatures(field1: Double, field2: Int, field3: Double, features: Vector) - val dataWithFeatures = vectorAssembler.transform(trainingDataDs).as[DataWithFeatures]() + case class DataWithFeatures( + field1: Double, + field2: Int, + field3: Double, + features: Vector) + val dataWithFeatures = + vectorAssembler.transform(trainingDataDs).as[DataWithFeatures]() case class RFInputs(field3: Double, features: Vector) val rf = TypedRandomForestRegressor[RFInputs] @@ -28,15 +33,28 @@ class RegressionIntegrationTests extends FramelessMlSuite with Matchers { // Prediction - val testData = TypedDataset.create(Seq( - Data(0D, 10, 0D) - )) - val testDataWithFeatures = vectorAssembler.transform(testData).as[DataWithFeatures]() - - case class PredictionResult(field1: Double, field2: Int, field3: Double, features: Vector, predictedField3: Double) - val predictionDs = model.transform(testDataWithFeatures).as[PredictionResult]() - - val prediction = predictionDs.select(predictionDs.col('predictedField3)).collect().run().toList + val testData = TypedDataset.create( + Seq( + Data(0D, 10, 0D) + ) + ) + val testDataWithFeatures = + vectorAssembler.transform(testData).as[DataWithFeatures]() + + case class PredictionResult( + field1: Double, + field2: Int, + field3: Double, + features: Vector, + predictedField3: Double) + val predictionDs = + model.transform(testDataWithFeatures).as[PredictionResult]() + + val prediction = predictionDs + .select(predictionDs.col('predictedField3)) + .collect() + .run() + .toList prediction mustEqual List(0D) } diff --git a/ml/src/test/scala/frameless/ml/regression/TypedLinearRegressionTests.scala b/ml/src/test/scala/frameless/ml/regression/TypedLinearRegressionTests.scala index b864b1533..5fbccb087 100644 --- a/ml/src/test/scala/frameless/ml/regression/TypedLinearRegressionTests.scala +++ b/ml/src/test/scala/frameless/ml/regression/TypedLinearRegressionTests.scala @@ -2,7 +2,7 @@ package frameless package ml package regression -import 
frameless.ml.params.linears.{LossStrategy, Solver} +import frameless.ml.params.linears.{ LossStrategy, Solver } import org.apache.spark.ml.linalg._ import org.scalacheck.Arbitrary import org.scalacheck.Prop._ @@ -11,7 +11,9 @@ import shapeless.test.illTyped class TypedLinearRegressionTests extends FramelessMlSuite with Matchers { - implicit val arbVectorNonEmpty: Arbitrary[Vector] = Arbitrary(Generators.arbVector.arbitrary) + implicit val arbVectorNonEmpty: Arbitrary[Vector] = Arbitrary( + Generators.arbVector.arbitrary + ) test("fit() returns a correct TypedTransformer") { val prop = forAll { x2: X2[Double, Vector] => @@ -32,14 +34,18 @@ class TypedLinearRegressionTests extends FramelessMlSuite with Matchers { pDs.select(pDs.col('a), pDs.col('b)).collect().run() == Seq(x2.a -> x2.b) } - def prop3[A: TypedEncoder: Arbitrary] = forAll { x3: X3[Vector, Double, A] => - val lr = TypedLinearRegression[X2[Vector, Double]] - val ds = TypedDataset.create(Seq(x3)) - val model = lr.fit(ds).run() - val pDs = model.transform(ds).as[X4[Vector, Double, A, Double]]() + def prop3[A: TypedEncoder: Arbitrary] = + forAll { x3: X3[Vector, Double, A] => + val lr = TypedLinearRegression[X2[Vector, Double]] + val ds = TypedDataset.create(Seq(x3)) + val model = lr.fit(ds).run() + val pDs = model.transform(ds).as[X4[Vector, Double, A, Double]]() - pDs.select(pDs.col('a), pDs.col('b), pDs.col('c)).collect().run() == Seq((x3.a, x3.b, x3.c)) - } + pDs + .select(pDs.col('a), pDs.col('b), pDs.col('c)) + .collect() + .run() == Seq((x3.a, x3.b, x3.c)) + } check(prop) check(prop2) @@ -48,7 +54,7 @@ class TypedLinearRegressionTests extends FramelessMlSuite with Matchers { } test("param setting is retained") { - import Generators.{arbLossStrategy, arbSolver} + import Generators.{ arbLossStrategy, arbSolver } val prop = forAll { (lossStrategy: LossStrategy, solver: Solver) => val lr = TypedLinearRegression[X2[Double, Vector]] @@ -66,12 +72,12 @@ class TypedLinearRegressionTests extends FramelessMlSuite with Matchers { val model = lr.fit(ds).run() model.transformer.getAggregationDepth == 10 && - model.transformer.getEpsilon == 4.0 && - model.transformer.getLoss == lossStrategy.sparkValue && - model.transformer.getMaxIter == 23 && - model.transformer.getRegParam == 1.2 && - model.transformer.getTol == 2.3 && - model.transformer.getSolver == solver.sparkValue + model.transformer.getEpsilon == 4.0 && + model.transformer.getLoss == lossStrategy.sparkValue && + model.transformer.getMaxIter == 23 && + model.transformer.getRegParam == 1.2 && + model.transformer.getTol == 2.3 && + model.transformer.getSolver == solver.sparkValue } check(prop) @@ -98,25 +104,23 @@ class TypedLinearRegressionTests extends FramelessMlSuite with Matchers { ) val ds2 = Seq( - X3(new DenseVector(Array(1.0)): Vector,2F, 1.0), - X3(new DenseVector(Array(2.0)): Vector,2F, 2.0), - X3(new DenseVector(Array(3.0)): Vector,2F, 3.0), - X3(new DenseVector(Array(4.0)): Vector,2F, 4.0), - X3(new DenseVector(Array(5.0)): Vector,2F, 5.0), - X3(new DenseVector(Array(6.0)): Vector,2F, 6.0) + X3(new DenseVector(Array(1.0)): Vector, 2F, 1.0), + X3(new DenseVector(Array(2.0)): Vector, 2F, 2.0), + X3(new DenseVector(Array(3.0)): Vector, 2F, 3.0), + X3(new DenseVector(Array(4.0)): Vector, 2F, 4.0), + X3(new DenseVector(Array(5.0)): Vector, 2F, 5.0), + X3(new DenseVector(Array(6.0)): Vector, 2F, 6.0) ) val tds = TypedDataset.create(ds) - val lr = TypedLinearRegression[X2[Vector, Double]] - .setMaxIter(10) + val lr = TypedLinearRegression[X2[Vector, Double]].setMaxIter(10) val 
model = lr.fit(tds).run() val tds2 = TypedDataset.create(ds2) - val lr2 = TypedLinearRegression[X3[Vector, Float, Double]] - .setMaxIter(10) + val lr2 = TypedLinearRegression[X3[Vector, Float, Double]].setMaxIter(10) val model2 = lr2.fit(tds2).run() diff --git a/ml/src/test/scala/frameless/ml/regression/TypedRandomForestRegressorTests.scala b/ml/src/test/scala/frameless/ml/regression/TypedRandomForestRegressorTests.scala index 4a6cd37d2..cc37a6b6e 100644 --- a/ml/src/test/scala/frameless/ml/regression/TypedRandomForestRegressorTests.scala +++ b/ml/src/test/scala/frameless/ml/regression/TypedRandomForestRegressorTests.scala @@ -10,8 +10,11 @@ import org.scalacheck.Prop._ import org.scalatest.matchers.must.Matchers class TypedRandomForestRegressorTests extends FramelessMlSuite with Matchers { + implicit val arbVectorNonEmpty: Arbitrary[Vector] = - Arbitrary(Generators.arbVector.arbitrary suchThat (_.size > 0)) // vector must not be empty for RandomForestRegressor + Arbitrary( + Generators.arbVector.arbitrary suchThat (_.size > 0) + ) // vector must not be empty for RandomForestRegressor import Generators.arbTreesFeaturesSubsetStrategy test("fit() returns a correct TypedTransformer") { @@ -33,14 +36,18 @@ class TypedRandomForestRegressorTests extends FramelessMlSuite with Matchers { pDs.select(pDs.col('a), pDs.col('b)).collect().run() == Seq(x2.a -> x2.b) } - def prop3[A: TypedEncoder: Arbitrary] = forAll { x3: X3[Vector, Double, A] => - val rf = TypedRandomForestRegressor[X2[Vector, Double]] - val ds = TypedDataset.create(Seq(x3)) - val model = rf.fit(ds).run() - val pDs = model.transform(ds).as[X4[Vector, Double, A, Double]]() + def prop3[A: TypedEncoder: Arbitrary] = + forAll { x3: X3[Vector, Double, A] => + val rf = TypedRandomForestRegressor[X2[Vector, Double]] + val ds = TypedDataset.create(Seq(x3)) + val model = rf.fit(ds).run() + val pDs = model.transform(ds).as[X4[Vector, Double, A, Double]]() - pDs.select(pDs.col('a), pDs.col('b), pDs.col('c)).collect().run() == Seq((x3.a, x3.b, x3.c)) - } + pDs + .select(pDs.col('a), pDs.col('b), pDs.col('c)) + .collect() + .run() == Seq((x3.a, x3.b, x3.c)) + } check(prop) check(prop2) @@ -64,13 +71,13 @@ class TypedRandomForestRegressorTests extends FramelessMlSuite with Matchers { val model = rf.fit(ds).run() model.transformer.getNumTrees == 10 && - model.transformer.getMaxBins == 100 && - model.transformer.getFeatureSubsetStrategy == featureSubsetStrategy.sparkValue && - model.transformer.getMaxDepth == 10 && - model.transformer.getMaxMemoryInMB == 100 && - model.transformer.getMinInfoGain == 0.1D && - model.transformer.getMinInstancesPerNode == 2 && - model.transformer.getSubsamplingRate == 0.9D + model.transformer.getMaxBins == 100 && + model.transformer.getFeatureSubsetStrategy == featureSubsetStrategy.sparkValue && + model.transformer.getMaxDepth == 10 && + model.transformer.getMaxMemoryInMB == 100 && + model.transformer.getMinInfoGain == 0.1D && + model.transformer.getMinInstancesPerNode == 2 && + model.transformer.getSubsamplingRate == 0.9D } check(prop) diff --git a/refined/src/main/scala/frameless/refined/RefinedFieldEncoders.scala b/refined/src/main/scala/frameless/refined/RefinedFieldEncoders.scala index dba59454c..2dacb45d9 100644 --- a/refined/src/main/scala/frameless/refined/RefinedFieldEncoders.scala +++ b/refined/src/main/scala/frameless/refined/RefinedFieldEncoders.scala @@ -4,7 +4,10 @@ import scala.reflect.ClassTag import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.{ - 
Invoke, NewInstance, UnwrapOption, WrapOption + Invoke, + NewInstance, + UnwrapOption, + WrapOption } import org.apache.spark.sql.types._ @@ -13,15 +16,16 @@ import eu.timepit.refined.api.RefType import frameless.{ TypedEncoder, RecordFieldEncoder } private[refined] trait RefinedFieldEncoders { + /** * @tparam T the refined type (e.g. `String`) */ implicit def optionRefined[F[_, _], T, R]( - implicit + implicit i0: RefType[F], i1: TypedEncoder[T], - i2: ClassTag[F[T, R]], - ): RecordFieldEncoder[Option[F[T, R]]] = + i2: ClassTag[F[T, R]] + ): RecordFieldEncoder[Option[F[T, R]]] = RecordFieldEncoder[Option[F[T, R]]](new TypedEncoder[Option[F[T, R]]] { def nullable = true @@ -54,11 +58,11 @@ private[refined] trait RefinedFieldEncoders { * @tparam T the refined type (e.g. `String`) */ implicit def refined[F[_, _], T, R]( - implicit + implicit i0: RefType[F], i1: TypedEncoder[T], - i2: ClassTag[F[T, R]], - ): RecordFieldEncoder[F[T, R]] = + i2: ClassTag[F[T, R]] + ): RecordFieldEncoder[F[T, R]] = RecordFieldEncoder[F[T, R]](new TypedEncoder[F[T, R]] { def nullable = i1.nullable @@ -76,4 +80,3 @@ private[refined] trait RefinedFieldEncoders { override def toString = s"refined[${i2.runtimeClass.getName}]" }) } - diff --git a/refined/src/main/scala/frameless/refined/package.scala b/refined/src/main/scala/frameless/refined/package.scala index 8819be2bf..2786e14e3 100644 --- a/refined/src/main/scala/frameless/refined/package.scala +++ b/refined/src/main/scala/frameless/refined/package.scala @@ -5,8 +5,9 @@ import scala.reflect.ClassTag import eu.timepit.refined.api.{ RefType, Validate } package object refined extends RefinedFieldEncoders { + implicit def refinedInjection[F[_, _], T, R]( - implicit + implicit refType: RefType[F], validate: Validate[T, R] ): Injection[F[T, R], T] = Injection( @@ -15,19 +16,20 @@ package object refined extends RefinedFieldEncoders { refType.refine[R](value) match { case Left(errMsg) => throw new IllegalArgumentException( - s"Value $value does not satisfy refinement predicate: $errMsg") + s"Value $value does not satisfy refinement predicate: $errMsg" + ) case Right(res) => res } - }) + } + ) implicit def refinedEncoder[F[_, _], T, R]( - implicit + implicit i0: RefType[F], i1: Validate[T, R], i2: TypedEncoder[T], i3: ClassTag[F[T, R]] - ): TypedEncoder[F[T, R]] = TypedEncoder.usingInjection( - i3, refinedInjection, i2) + ): TypedEncoder[F[T, R]] = + TypedEncoder.usingInjection(i3, refinedInjection, i2) } - diff --git a/refined/src/test/scala/frameless/RefinedFieldEncoderTests.scala b/refined/src/test/scala/frameless/RefinedFieldEncoderTests.scala index 5476284ea..c152abf34 100644 --- a/refined/src/test/scala/frameless/RefinedFieldEncoderTests.scala +++ b/refined/src/test/scala/frameless/RefinedFieldEncoderTests.scala @@ -2,7 +2,11 @@ package frameless import org.apache.spark.sql.Row import org.apache.spark.sql.types.{ - IntegerType, ObjectType, StringType, StructField, StructType + IntegerType, + ObjectType, + StringType, + StructField, + StructType } import org.scalatest.matchers.should.Matchers @@ -24,7 +28,8 @@ class RefinedFieldEncoderTests extends TypedDatasetSuite with Matchers { val nes: NonEmptyString = "Non Empty String" - val unsafeDs = TypedDataset.createUnsafe(sc.parallelize(Seq(nes.value)).toDF())(encoder) + val unsafeDs = + TypedDataset.createUnsafe(sc.parallelize(Seq(nes.value)).toDF())(encoder) val expected = Seq(nes) @@ -40,9 +45,12 @@ class RefinedFieldEncoderTests extends TypedDatasetSuite with Matchers { encoderA.jvmRepr shouldBe 
ObjectType(classOf[A]) // Check catalystRepr - val expectedAStructType = StructType(Seq( - StructField("a", IntegerType, false), - StructField("s", StringType, false))) + val expectedAStructType = StructType( + Seq( + StructField("a", IntegerType, false), + StructField("s", StringType, false) + ) + ) encoderA.catalystRepr shouldBe expectedAStructType @@ -71,18 +79,23 @@ class RefinedFieldEncoderTests extends TypedDatasetSuite with Matchers { encoderB.jvmRepr shouldBe ObjectType(classOf[B]) // Check catalystRepr - val expectedBStructType = StructType(Seq( - StructField("a", IntegerType, false), - StructField("s", StringType, true))) + val expectedBStructType = StructType( + Seq( + StructField("a", IntegerType, false), + StructField("s", StringType, true) + ) + ) encoderB.catalystRepr shouldBe expectedBStructType // Check unsafe val unsafeDs: TypedDataset[B] = { - val rdd = sc.parallelize(Seq( - Row(bs.a, bs.s.mkString), - Row(2, null.asInstanceOf[String]), - )) + val rdd = sc.parallelize( + Seq( + Row(bs.a, bs.s.mkString), + Row(2, null.asInstanceOf[String]) + ) + ) val df = session.createDataFrame(rdd, expectedBStructType)
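// Sketch: with the refined injection and encoders above in implicit scope, refined types can
// be used directly as column types. Person is a hypothetical example; the literal is refined
// at compile time via refined's auto machinery.
import eu.timepit.refined.api.Refined
import eu.timepit.refined.auto._
import eu.timepit.refined.collection.NonEmpty
import frameless.TypedEncoder
import frameless.refined._

case class Person(id: Int, name: String Refined NonEmpty)

object RefinedEncoderExample {
  // Derives through refinedEncoder / the refined RecordFieldEncoder instances, mapping the
  // refined field to a plain StringType column in the Catalyst schema.
  val personEncoder: TypedEncoder[Person] = TypedEncoder[Person]

  val alice: Person = Person(1, "Alice")
}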