Skip to content

Commit

Permalink
add withColumnTyped
Browse files Browse the repository at this point in the history
  • Loading branch information
frosforever committed Nov 16, 2017
1 parent f300c98 commit 835b774
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 1 deletion.
92 changes: 91 additions & 1 deletion dataset/src/main/scala/frameless/TypedDataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter}
import org.apache.spark.sql._
import shapeless._
import shapeless.ops.hlist.{Prepend, ToTraversable, Tupler}
import shapeless.ops.hlist.{Align, Prepend, ToTraversable, Tupler}
import shapeless.ops.record.{Keys, Values}

/** [[TypedDataset]] is a safer interface for working with `Dataset`.
*
Expand Down Expand Up @@ -626,6 +627,95 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val

TypedDataset.create[Out](selected)
}

/**
* Adds a column to a Dataset so long as the specified output type, `U`, has
* an extra column from `T` that has type `A`.
*
* @example
* {{{
* case class X(i: Int, j: Int)
* case class Y(i: Int, j: Int, k: Boolean)
* val f: TypedDataset[X] = TypedDataset.create(X(1,1) :: X(1,1) :: X(1,10) :: Nil)
* val fNew: TypedDataset[Y] = f.withColumnTyped[Y](f('j) === 10)
* }}}
* @param ca The typed column to add
* @param uEncder TypeEncoder for output type U
* @param aEncoder TypeEncoder for added column type `A`
* @param tgen the LabelledGeneric derived for T
* @param ugen the LabelledGeneric derived for U
* @param tKeys the keys of T
* @param uKeys the keys of U
* @param tValues the values of T
* @param uValues the values of U
* @param prepend the values of T with column A added
* @param align the values from `prepend` aligned to reforestation of U.
* Also serves as a typed proof that T + A = U
* @param tKeysTraverse allows for traversing the keys of T
* @param uKeysTraverse allows for traversing the keys of U
* @tparam U the output type
* @tparam A The added column type
* @tparam TRep shapeless' record representation of T
* @tparam URep shapeless' record representation of U
* @tparam TKeys the keys of T as an HList
* @tparam UKeys the keys of U as an HList
* @tparam TValues the values of T as an HList
* @tparam UValues the values of U as an HList
* @tparam Unaligned the values of T with A prepended
*
* @see [[frameless.TypedDataset.withColumnApply#apply]]
*/
def withColumnTyped[U] = new withColumnApply[U]

class withColumnApply[U] {
def apply[
A,
TRep <: HList,
URep <: HList,
TKeys <: HList,
UKeys <: HList,
TValues <: HList,
UValues <: HList,
Unaligned <: HList
](
ca : TypedColumn[T, A]
)(implicit
uEncder: TypedEncoder[U],
aEncoder: TypedEncoder[A],

tgen: LabelledGeneric.Aux[T, TRep],
ugen: LabelledGeneric.Aux[U, URep],
tKeys: Keys.Aux[TRep, TKeys],
uKeys: Keys.Aux[URep, UKeys],

tValues: Values.Aux[TRep, TValues],
uValues: Values.Aux[URep, UValues],

prepend: Prepend.Aux[TValues, A :: HNil, Unaligned],
align: Align[Unaligned, UValues],

tKeysTraverse: ToTraversable.Aux[TKeys, Seq, Symbol],
uKeysTraverse: ToTraversable.Aux[UKeys, Seq, Symbol]
) = {
val newNames = uKeys.apply.to[Seq].map(_.name)
val oldNames = tKeys.apply.to[Seq].map(_.name)

//We already know there's only one (or at least there better be)
val newColumnName = newNames.toSet.diff(oldNames.toSet).head

val dfWithNewColumn = dataset
.toDF()
.withColumn(newColumnName, ca.untyped)

val newColumns = newNames.map(dfWithNewColumn.col)

val selected = dfWithNewColumn
.select(newColumns: _*)
.as[U](TypedExpressionEncoder[U])

TypedDataset.create[U](selected)
}
}
}

object TypedDataset {
Expand Down
24 changes: 24 additions & 0 deletions dataset/src/test/scala/frameless/WithColumnTypedTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package frameless

import org.scalacheck.Prop
import org.scalacheck.Prop._

class WithColumnTypedTest extends TypedDatasetSuite {
test("append four columns") {
def prop[A: TypedEncoder](value: A): Prop = {
val d = TypedDataset.create(X1(value) :: Nil)
val d1 = d.withColumnTyped[X2[A, A]](d('a))
val d2 = d1.withColumnTyped[X3[A, A, A]](d1('b))
val d3 = d2.withColumnTyped[X4[A, A, A, A]](d2('c))
val d4 = d3.withColumnTyped[X5[A, A, A, A, A]](d3('d))

X5(value, value, value, value, value) ?= d4.collect().run().head
}

check(prop[Int] _)
check(prop[Long] _)
check(prop[String] _)
check(prop[SQLDate] _)
check(prop[Option[X1[Boolean]]] _)
}
}

0 comments on commit 835b774

Please sign in to comment.