From cf29fc5f931e2a650eaeb6b4c08ed6e457f1b073 Mon Sep 17 00:00:00 2001
From: Andrew Ray
Date: Tue, 1 Aug 2017 15:52:03 -0500
Subject: [PATCH 1/8] should work; needs tests
---
 .../sql/catalyst/expressions/predicates.scala | 56 +++++--------------
 .../spark/sql/types/AbstractDataType.scala    | 12 ----
 .../ExpressionTypeCheckingSuite.scala         | 16 +++---
 3 files changed, 22 insertions(+), 62 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 7bf10f199f1c7..ca5c731661988 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -453,6 +453,14 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
 
 abstract class BinaryComparison extends BinaryOperator with Predicate {
 
+  override def inputType: AbstractDataType = AnyDataType
+
+  override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
+    case TypeCheckResult.TypeCheckSuccess =>
+      TypeUtils.checkForOrderingExpr(left.dataType, this.getClass.getSimpleName)
+    case failure => failure
+  }
+
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     if (ctx.isPrimitiveType(left.dataType)
         && left.dataType != BooleanType // java boolean doesn't support > or < operator
@@ -465,7 +473,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
     }
   }
 
-  protected lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
+  protected lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(left.dataType)
 }
 
 
@@ -483,28 +491,13 @@ object Equality {
   }
 }
 
+// TODO: although map type is not orderable, technically map type should be able to be used
+// in equality comparison
 @ExpressionDescription(
   usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` equals `expr2`, or false otherwise.")
 case class EqualTo(left: Expression, right: Expression)
     extends BinaryComparison with NullIntolerant {
 
-  override def inputType: AbstractDataType = AnyDataType
-
-  override def checkInputDataTypes(): TypeCheckResult = {
-    super.checkInputDataTypes() match {
-      case TypeCheckResult.TypeCheckSuccess =>
-        // TODO: although map type is not orderable, technically map type should be able to be used
-        // in equality comparison, remove this type check once we support it.
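Annotation on the change in PATCH 1/8: rather than whitelisting orderable types through TypeCollection.Ordered, every BinaryComparison now advertises AnyDataType up front and rejects unorderable inputs per concrete type via TypeUtils.checkForOrderingExpr. Below is a minimal hand-written sketch of the recursive test this relies on; in Spark the real check is RowOrdering.isOrderable, AtomicType is only visible inside the sql package, and details may differ, so treat this as illustration rather than the actual source.

    import org.apache.spark.sql.types._

    // Sketch: atomic types are orderable; arrays and structs are orderable exactly when
    // their element/field types are; maps never are. This mirrors what the new
    // checkInputDataTypes path enforces instead of a fixed type whitelist.
    def isOrderable(dt: DataType): Boolean = dt match {
      case NullType => true
      case _: AtomicType => true // numeric, string, binary, boolean, date, timestamp, ...
      case ArrayType(elementType, _) => isOrderable(elementType)
      case StructType(fields) => fields.forall(f => isOrderable(f.dataType))
      case _ => false // MapType, in particular, remains unorderable
    }

The error text the updated tests expect below ("<Op> does not support ordering on type MapType") presumably comes from checkForOrderingExpr when this test fails, which is why each operator passes its own class name to it.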
-        if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) {
-          TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualTo, but the actual " +
-            s"input type is ${left.dataType.catalogString}.")
-        } else {
-          TypeCheckResult.TypeCheckSuccess
-        }
-      case failure => failure
-    }
-  }
-
   override def symbol: String = "="
 
   protected override def nullSafeEval(left: Any, right: Any): Any = ordering.equiv(left, right)
@@ -514,6 +507,8 @@ case class EqualTo(left: Expression, right: Expression)
   }
 }
 
+// TODO: although map type is not orderable, technically map type should be able to be used
+// in equality comparison
 @ExpressionDescription(
   usage = """
     expr1 _FUNC_ expr2 - Returns same result as the EQUAL(=) operator for non-null operands,
@@ -521,23 +516,6 @@ case class EqualTo(left: Expression, right: Expression)
   """)
 case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison {
 
-  override def inputType: AbstractDataType = AnyDataType
-
-  override def checkInputDataTypes(): TypeCheckResult = {
-    super.checkInputDataTypes() match {
-      case TypeCheckResult.TypeCheckSuccess =>
-        // TODO: although map type is not orderable, technically map type should be able to be used
-        // in equality comparison, remove this type check once we support it.
-        if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) {
-          TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualNullSafe, but the actual " +
-            s"input type is ${left.dataType.catalogString}.")
-        } else {
-          TypeCheckResult.TypeCheckSuccess
-        }
-      case failure => failure
-    }
-  }
-
   override def symbol: String = "<=>"
 
   override def nullable: Boolean = false
@@ -569,8 +547,6 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
 case class LessThan(left: Expression, right: Expression)
     extends BinaryComparison with NullIntolerant {
 
-  override def inputType: AbstractDataType = TypeCollection.Ordered
-
   override def symbol: String = "<"
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2)
@@ -581,8 +557,6 @@ case class LessThan(left: Expression, right: Expression)
 case class LessThanOrEqual(left: Expression, right: Expression)
     extends BinaryComparison with NullIntolerant {
 
-  override def inputType: AbstractDataType = TypeCollection.Ordered
-
   override def symbol: String = "<="
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2)
@@ -593,8 +567,6 @@ case class LessThanOrEqual(left: Expression, right: Expression)
 case class GreaterThan(left: Expression, right: Expression)
     extends BinaryComparison with NullIntolerant {
 
-  override def inputType: AbstractDataType = TypeCollection.Ordered
-
   override def symbol: String = ">"
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2)
@@ -605,8 +577,6 @@ case class GreaterThan(left: Expression, right: Expression)
 case class GreaterThanOrEqual(left: Expression, right: Expression)
     extends BinaryComparison with NullIntolerant {
 
-  override def inputType: AbstractDataType = TypeCollection.Ordered
-
   override def symbol: String = ">="
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
index 1d54ff5825c2e..3041f44b116ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
+++
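The comparison semantics the rest of this series leans on: structs compare field by field (first difference wins) and arrays compare element by element, then by length. A plain-Scala analogue of those relations using standard-library orderings (not Catalyst's InterpretedOrdering, which may differ in details such as null handling), just to pin down the expected relations between the test literals introduced later:

    import scala.math.Ordering.Implicits._

    // Tuples stand in for structs, Seqs for arrays.
    assert((1L, "b") < (2L, "a"))      // first field 1 < 2 decides; "b" > "a" is ignored
    assert(Seq(1L, 2L) < Seq(2L, 1L))  // first differing element 1 < 2 decides
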
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala
@@ -78,18 +78,6 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])
 
 private[sql] object TypeCollection {
 
-  /**
-   * Types that can be ordered/compared. In the long run we should probably make this a trait
-   * that can be mixed into each data type, and perhaps create an `AbstractDataType`.
-   */
-  // TODO: Should we consolidate this with RowOrdering.isOrderable?
-  val Ordered = TypeCollection(
-    BooleanType,
-    ByteType, ShortType, IntegerType, LongType,
-    FloatType, DoubleType, DecimalType,
-    TimestampType, DateType,
-    StringType, BinaryType)
-
   /**
    * Types that include numeric types and interval type. They are only used in unary_minus,
    * unary_positive, add and subtract operations.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 30725773a37b1..1044fee683205 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -18,13 +18,14 @@
 package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.SparkFunSuite
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.types.{LongType, StringType, TypeCollection}
+import org.apache.spark.sql.types._
 
 class ExpressionTypeCheckingSuite extends SparkFunSuite {
 
@@ -109,16 +110,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))
     assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField))
 
-    assertError(EqualTo('mapField, 'mapField), "Cannot use map type in EqualTo")
-    assertError(EqualNullSafe('mapField, 'mapField), "Cannot use map type in EqualNullSafe")
+    assertError(EqualTo('mapField, 'mapField), "EqualTo does not support ordering on type MapType")
+    assertError(EqualNullSafe('mapField, 'mapField),
+      "EqualNullSafe does not support ordering on type MapType")
     assertError(LessThan('mapField, 'mapField),
-      s"requires ${TypeCollection.Ordered.simpleString} type")
+      "LessThan does not support ordering on type MapType")
     assertError(LessThanOrEqual('mapField, 'mapField),
-      s"requires ${TypeCollection.Ordered.simpleString} type")
+      "LessThanOrEqual does not support ordering on type MapType")
     assertError(GreaterThan('mapField, 'mapField),
-      s"requires ${TypeCollection.Ordered.simpleString} type")
+      "GreaterThan does not support ordering on type MapType")
     assertError(GreaterThanOrEqual('mapField, 'mapField),
-      s"requires ${TypeCollection.Ordered.simpleString} type")
+      "GreaterThanOrEqual does not support ordering on type MapType")
 
     assertError(If('intField, 'stringField, 'stringField),
       "type of predicate expression in If should be boolean")

From 0d1fd568f2af298bfb72ed8ca2f2560fa935b6f6 Mon Sep 17 00:00:00 2001
From: Andrew Ray
Date: Wed, 2 Aug 2017 09:57:54 -0500
Subject: [PATCH 2/8] update unit test with array and struct types
---
 .../sql/catalyst/expressions/PredicateSuite.scala | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index ef510a95ef446..32d4dbc7b78a2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -215,14 +215,21 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
-  private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false).map(Literal(_))
+  case class MyStruct(a: Long, b: String)
+
+  private val smallValues =
+    Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false, Array(1L, 2L))
+      .map(Literal(_)) :+ Literal.create(MyStruct(1L, "b"))
   private val largeValues =
-    Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN, true).map(Literal(_))
+    Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN, true, Array(2L, 1L))
+      .map(Literal(_)) :+ Literal.create(MyStruct(2L, "a"))
 
   private val equalValues1 =
-    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_))
+    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true, Array(1L, 2L))
+      .map(Literal(_)) :+ Literal.create(MyStruct(1L, "a"))
   private val equalValues2 =
-    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_))
+    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true, Array(1L, 2L))
+      .map(Literal(_)) :+ Literal.create(MyStruct(1L, "a"))
 
   test("BinaryComparison consistency check") {
     DataTypeTestUtils.ordered.foreach { dt =>

From d1c7565fafa4a1a3ef411ff8c7ebe498a01c4f51 Mon Sep 17 00:00:00 2001
From: Andrew Ray
Date: Wed, 2 Aug 2017 10:22:15 -0500
Subject: [PATCH 3/8] fix style
---
 .../sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 1044fee683205..36714bd631b0e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.SparkFunSuite
-
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._

From ec8dc950874c9c5864120e27b3e3fd01d1a3b28e Mon Sep 17 00:00:00 2001
From: Andrew Ray
Date: Wed, 2 Aug 2017 10:25:02 -0500
Subject: [PATCH 4/8] make MyStruct private
---
 .../apache/spark/sql/catalyst/expressions/PredicateSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 32d4dbc7b78a2..980ba91f0b13b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -215,7 +215,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
-  case class MyStruct(a: Long, b: String)
+  private case class MyStruct(a: Long, b: String)
 
   private val smallValues =
     Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false, Array(1L, 2L))
From caf74bf316c35da68f9c0ec6c1d6eaf75b4e5eb1 Mon Sep 17 00:00:00 2001
From: Andrew Ray
Date: Wed, 2 Aug 2017 12:45:46 -0500
Subject: [PATCH 5/8] fix test
---
 .../apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 4e0613619add6..884e113537c93 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -505,7 +505,7 @@ class AnalysisErrorSuite extends AnalysisTest {
       right,
       joinType = Cross,
       condition = Some('b === 'd))
-    assertAnalysisError(plan2, "Cannot use map type in EqualTo" :: Nil)
+    assertAnalysisError(plan2, "EqualTo does not support ordering on type MapType" :: Nil)
   }
 
   test("PredicateSubQuery is used outside of a filter") {

From c4f43e90bd9627300d2df5eee2e9a93042696936 Mon Sep 17 00:00:00 2001
From: Andrew Ray
Date: Tue, 8 Aug 2017 15:55:35 -0500
Subject: [PATCH 6/8] additional unit tests and comment for inputType
---
 .../sql/catalyst/expressions/predicates.scala |  2 ++
 .../catalyst/expressions/PredicateSuite.scala | 27 +++++++++++++++++++--------
 2 files changed, 21 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index ca5c731661988..e8f945d1ae682 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -453,6 +453,8 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
 
 abstract class BinaryComparison extends BinaryOperator with Predicate {
 
+  // Note that we need to give a superset of allowable input types since orderable types are not
+  // finitely enumerable. The allowable types are checked below by checkInputDataTypes.
   override def inputType: AbstractDataType = AnyDataType
 
   override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 980ba91f0b13b..72b61cd4a981b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -17,6 +17,8 @@
 package org.apache.spark.sql.catalyst.expressions
 
+import java.sql.{Date, Timestamp}
+
 import scala.collection.immutable.HashSet
 
 import org.apache.spark.SparkFunSuite
@@ -216,20 +218,29 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
   private case class MyStruct(a: Long, b: String)
+  private case class MyStruct2(a: MyStruct, b: Array[Int])
 
   private val smallValues =
-    Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false, Array(1L, 2L))
-      .map(Literal(_)) :+ Literal.create(MyStruct(1L, "b"))
+    Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1),
+      new Timestamp(1), "a", 1f, 1d, 0f, 0d, false, Array(1L, 2L))
+      .map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
+      Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))))
   private val largeValues =
-    Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN, true, Array(2L, 1L))
-      .map(Literal(_)) :+ Literal.create(MyStruct(2L, "a"))
+    Seq(2.toByte, 2.toShort, 2, 2L, Decimal(2), Array(2.toByte), new Date(2000, 1, 2),
+      new Timestamp(2), "b", 2f, 2d, Float.NaN, Double.NaN, true, Array(2L, 1L))
+      .map(Literal(_)) ++ Seq(Literal.create(MyStruct(2L, "b")),
+      Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 2))))
 
   private val equalValues1 =
-    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true, Array(1L, 2L))
-      .map(Literal(_)) :+ Literal.create(MyStruct(1L, "a"))
+    Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1),
+      new Timestamp(1), "a", 1f, 1d, Float.NaN, Double.NaN, true, Array(1L, 2L))
+      .map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
+      Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))))
   private val equalValues2 =
-    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true, Array(1L, 2L))
-      .map(Literal(_)) :+ Literal.create(MyStruct(1L, "a"))
+    Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1),
+      new Timestamp(1), "a", 1f, 1d, Float.NaN, Double.NaN, true, Array(1L, 2L))
+      .map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
+      Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))))
 
   test("BinaryComparison consistency check") {
     DataTypeTestUtils.ordered.foreach { dt =>
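The new inputType comment in PATCH 6/8 is worth unpacking: orderable types cannot be listed in a TypeCollection because array and struct types are parameterized, so there are infinitely many of them. A few members of that family, purely for illustration:

    import org.apache.spark.sql.types._

    // No finite whitelist can cover these: every element type and nesting depth yields a
    // distinct orderable DataType, hence inputType = AnyDataType plus a recursive check
    // in checkInputDataTypes via TypeUtils.checkForOrderingExpr.
    val someOrderableTypes: Seq[DataType] = Seq(
      ArrayType(IntegerType),
      ArrayType(ArrayType(StringType)), // nesting is unbounded
      new StructType().add("a", LongType).add("b", StringType),
      ArrayType(new StructType().add("x", DoubleType))
    )
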
From cc2f3eca28ee6b9faa87853568205307567827cc Mon Sep 17 00:00:00 2001
From: Andrew Ray
Date: Mon, 14 Aug 2017 13:49:38 -0500
Subject: [PATCH 7/8] Fix codegen for NullType and ordering for UDTs; add
 tests for NullType and UDT
---
 .../expressions/codegen/CodeGenerator.scala   |  1 +
 .../spark/sql/catalyst/util/TypeUtils.scala   |  1 +
 .../catalyst/expressions/PredicateSuite.scala | 19 ++++++++++++++-----
 3 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index a014e2aa34820..0344fc6e8762e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -582,6 +582,7 @@ class CodegenContext {
     case array: ArrayType => genComp(array, c1, c2) + " == 0"
     case struct: StructType => genComp(struct, c1, c2) + " == 0"
     case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
+    case NullType => "true"
     case _ =>
       throw new IllegalArgumentException(
         "cannot generate equality code for un-comparable type: " + dataType.simpleString)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 45225779bffcb..1dcda49a3af6a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -65,6 +65,7 @@ object TypeUtils {
       case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
       case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
       case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
+      case udt: UserDefinedType[_] => getInterpretedOrdering(udt.sqlType)
     }
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 72b61cd4a981b..055c31c2b3018 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -24,7 +24,8 @@ import scala.collection.immutable.HashSet
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.RandomDataGenerator
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.types._
 
@@ -219,28 +220,33 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   private case class MyStruct(a: Long, b: String)
   private case class MyStruct2(a: MyStruct, b: Array[Int])
+  private val udt = new ExamplePointUDT
 
   private val smallValues =
     Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1),
       new Timestamp(1), "a", 1f, 1d, 0f, 0d, false, Array(1L, 2L))
       .map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
-      Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))))
+      Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))),
+      Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt))
   private val largeValues =
     Seq(2.toByte, 2.toShort, 2, 2L, Decimal(2), Array(2.toByte), new Date(2000, 1, 2),
       new Timestamp(2), "b", 2f, 2d, Float.NaN, Double.NaN, true, Array(2L, 1L))
       .map(Literal(_)) ++ Seq(Literal.create(MyStruct(2L, "b")),
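PATCH 7/8 makes UDTs orderable by delegating to the SQL type that backs them; ExamplePointUDT is backed by an array of doubles, which is why Array(1.0, 2.0) versus Array(1.0, 3.0) is a meaningful small/large pair in the test values. A short sketch of that delegation, assuming (as the test does) that the backing type is itself orderable:

    // After this patch the interpreted ordering for a UDT is simply the ordering of its
    // underlying SQL type, so points encoded as ArrayType(DoubleType) compare like arrays.
    val pointUdt = new org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
    val pointOrdering =
      org.apache.spark.sql.catalyst.util.TypeUtils.getInterpretedOrdering(pointUdt)
    // pointOrdering now behaves like ArrayType(DoubleType)'s ordering:
    // [1.0, 2.0] sorts before [1.0, 3.0].
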
-      Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 2))))
+      Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 2))),
+      Literal.create(ArrayData.toArrayData(Array(1.0, 3.0)), udt))
 
   private val equalValues1 =
     Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1),
       new Timestamp(1), "a", 1f, 1d, Float.NaN, Double.NaN, true, Array(1L, 2L))
       .map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
-      Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))))
+      Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))),
+      Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt))
   private val equalValues2 =
     Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1),
       new Timestamp(1), "a", 1f, 1d, Float.NaN, Double.NaN, true, Array(1L, 2L))
       .map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
-      Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))))
+      Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))),
+      Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt))
 
   test("BinaryComparison consistency check") {
     DataTypeTestUtils.ordered.foreach { dt =>
@@ -303,11 +309,13 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     // Use -1 (default value for codegen) which can trigger some weird bugs, e.g. SPARK-14757
     val normalInt = Literal(-1)
     val nullInt = NonFoldableLiteral.create(null, IntegerType)
+    val nullNullType = Literal.create(null, NullType)
 
     def nullTest(op: (Expression, Expression) => Expression): Unit = {
       checkEvaluation(op(normalInt, nullInt), null)
       checkEvaluation(op(nullInt, normalInt), null)
       checkEvaluation(op(nullInt, nullInt), null)
+      checkEvaluation(op(nullNullType, nullNullType), null)
     }
 
     nullTest(LessThan)
@@ -319,6 +327,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(EqualNullSafe(normalInt, nullInt), false)
     checkEvaluation(EqualNullSafe(nullInt, normalInt), false)
     checkEvaluation(EqualNullSafe(nullInt, nullInt), true)
+    checkEvaluation(EqualNullSafe(nullNullType, nullNullType), true)
   }
 
   test("EqualTo on complex type") {

From 6e011860ed800c9f869b66674cb241d3bb2d94fc Mon Sep 17 00:00:00 2001
From: Andrew Ray
Date: Thu, 31 Aug 2017 14:20:08 -0500
Subject: [PATCH 8/8] genEqual for NullType: true => false
---
 .../spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 0344fc6e8762e..a7b68abc2e241 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -582,7 +582,7 @@ class CodegenContext {
     case array: ArrayType => genComp(array, c1, c2) + " == 0"
     case struct: StructType => genComp(struct, c1, c2) + " == 0"
     case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
-    case NullType => "true"
+    case NullType => "false"
     case _ =>
       throw new IllegalArgumentException(
         "cannot generate equality code for un-comparable type: " + dataType.simpleString)
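On the final true => false flip: the equality code that genEqual emits is only consulted after the standard null guards have run, and NullType operands are always null, so the branch is effectively dead either way; "false" plausibly just keeps the dead branch from asserting an equality the generated code never checked, consistent with the interpreted path where nullSafeEval is never called on nulls. A toy model of the null-safe shape, with hypothetical names (leftIsNull, rightIsNull, genEqualResult) standing in for the variables the real codegen emits:

    // Toy model, not generated code: genEqualResult is consulted only when both sides
    // are non-null, which never happens for NullType, so "false" is the safe constant.
    def nullSafeEq(leftIsNull: Boolean, rightIsNull: Boolean, genEqualResult: Boolean): Boolean =
      (leftIsNull && rightIsNull) || (!leftIsNull && !rightIsNull && genEqualResult)

    // For NullType the first clause already yields true (matching the new
    // EqualNullSafe(nullNullType, nullNullType) == true test in PATCH 7/8), and the
    // constant emitted by genEqual is never observed.
    assert(nullSafeEq(leftIsNull = true, rightIsNull = true, genEqualResult = false))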