Skip to content

Commit

Permalink
Improve test case for In.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Sep 3, 2017
1 parent 099c671 commit 444c64d
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 138 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,8 @@ object Equality {
arguments = """
Arguments:
* expr1, expr2 - the two expressions must be same type or can be casted to a common type,
and must be a type that can be used in equality comparison.
and must be a type that can be used in equality comparison. Map type is not supported.
For complex types such array/struct, the data types of fields must be orderable.
""",
examples = """
Examples:
Expand Down Expand Up @@ -547,7 +548,8 @@ case class EqualTo(left: Expression, right: Expression)
arguments = """
Arguments:
* expr1, expr2 - the two expressions must be same type or can be casted to a common type,
and must be a type that can be used in equality comparison.
and must be a type that can be used in equality comparison. Map type is not supported.
For complex types such array/struct, the data types of fields must be orderable.
""",
examples = """
Examples:
Expand Down Expand Up @@ -593,7 +595,9 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
arguments = """
Arguments:
* expr1, expr2 - the two expressions must be same type or can be casted to a common type,
and must be a type that can be ordered/compared.
and must be a type that can be ordered. For example, map type is not orderable, so it
is not supported. For complex types such array/struct, the data types of fields must
be orderable.
""",
examples = """
Examples:
Expand Down Expand Up @@ -621,7 +625,9 @@ case class LessThan(left: Expression, right: Expression)
arguments = """
Arguments:
* expr1, expr2 - the two expressions must be same type or can be casted to a common type,
and must be a type that can be ordered/compared.
and must be a type that can be ordered. For example, map type is not orderable, so it
is not supported. For complex types such array/struct, the data types of fields must
be orderable.
""",
examples = """
Examples:
Expand Down Expand Up @@ -649,7 +655,9 @@ case class LessThanOrEqual(left: Expression, right: Expression)
arguments = """
Arguments:
* expr1, expr2 - the two expressions must be same type or can be casted to a common type,
and must be a type that can be ordered/compared.
and must be a type that can be ordered. For example, map type is not orderable, so it
is not supported. For complex types such array/struct, the data types of fields must
be orderable.
""",
examples = """
Examples:
Expand Down Expand Up @@ -677,7 +685,9 @@ case class GreaterThan(left: Expression, right: Expression)
arguments = """
Arguments:
* expr1, expr2 - the two expressions must be same type or can be casted to a common type,
and must be a type that can be ordered/compared.
and must be a type that can be ordered. For example, map type is not orderable, so it
is not supported. For complex types such array/struct, the data types of fields must
be orderable.
""",
examples = """
Examples:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
(null, false, null) ::
(null, null, null) :: Nil)

test("IN") {
test("basic IN predicate test") {
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1),
Literal(2))), null)
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType),
Expand Down Expand Up @@ -151,29 +151,63 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false)

val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
primitiveTypes.foreach { t =>
val dataGen = RandomDataGenerator.forType(t, nullable = true).get
val inputData = Seq.fill(10) {
val value = dataGen.apply()
value match {
case d: Double if d.isNaN => 0.0d
case f: Float if f.isNaN => 0.0f
case _ => value
}

test("IN with different types") {
def testWithRandomDataGeneration(dataType: DataType, nullable: Boolean): Unit = {
val dataGen = RandomDataGenerator.forType(dataType, nullable = nullable)
if (dataGen.isDefined) {
val inputData = Seq.fill(10) {
val value = dataGen.get.apply()
value match {
case d: Double if d.isNaN => 0.0d
case f: Float if f.isNaN => 0.0f
case _ => value
}
}
val input = inputData.map(NonFoldableLiteral.create(_, dataType))
val expected = if (inputData(0) == null) {
null
} else if (inputData.slice(1, 10).contains(inputData(0))) {
true
} else if (inputData.slice(1, 10).contains(null)) {
null
} else {
false
}
checkEvaluation(In(input(0), input.slice(1, 10)), expected)
}
val input = inputData.map(NonFoldableLiteral.create(_, t))
val expected = if (inputData(0) == null) {
null
} else if (inputData.slice(1, 10).contains(inputData(0))) {
true
} else if (inputData.slice(1, 10).contains(null)) {
null
} else {
false
}
checkEvaluation(In(input(0), input.slice(1, 10)), expected)
}

val atomicTypes = DataTypeTestUtils.atomicTypes.filter { t =>
RandomDataGenerator.forType(t).isDefined && !t.isInstanceOf[DecimalType]
} ++ Seq(DecimalType.USER_DEFAULT)

val atomicArrayTypes = atomicTypes.map(ArrayType(_, containsNull = true))

// Basic types:
for (
dataType <- atomicTypes;
nullable <- Seq(true, false)) {
testWithRandomDataGeneration(dataType, nullable)
}

// Array types:
for (
arrayType <- atomicArrayTypes;
nullable <- Seq(true, false)
if RandomDataGenerator.forType(arrayType.elementType, arrayType.containsNull).isDefined) {
testWithRandomDataGeneration(arrayType, nullable)
}

// Struct types:
for (
colOneType <- atomicTypes;
colTwoType <- atomicTypes;
nullable <- Seq(true, false)) {
val structType = StructType(
StructField("a", colOneType) :: StructField("b", colTwoType) :: Nil)
testWithRandomDataGeneration(structType, nullable)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
-- In
select 1 in(1, 2, 3);
select 1 in(2, 3, 4);
select named_struct('a', 1, 'b', 2) in(named_struct('a', 1, 'b', 1), named_struct('a', 1, 'b', 3));
select named_struct('a', 1, 'b', 2) in(named_struct('a', 1, 'b', 2), named_struct('a', 1, 'b', 3));

-- EqualTo
select 1 = 1;
select 1 = '1';
Expand Down
Loading

0 comments on commit 444c64d

Please sign in to comment.