diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index a014e2aa34820..a7b68abc2e241 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -582,6 +582,7 @@ class CodegenContext { case array: ArrayType => genComp(array, c1, c2) + " == 0" case struct: StructType => genComp(struct, c1, c2) + " == 0" case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2) + case NullType => "false" case _ => throw new IllegalArgumentException( "cannot generate equality code for un-comparable type: " + dataType.simpleString) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 7bf10f199f1c7..e8f945d1ae682 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -453,6 +453,16 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P abstract class BinaryComparison extends BinaryOperator with Predicate { + // Note that we need to give a superset of allowable input types since orderable types are not + // finitely enumerable. The allowable types are checked below by checkInputDataTypes. + override def inputType: AbstractDataType = AnyDataType + + override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + TypeUtils.checkForOrderingExpr(left.dataType, this.getClass.getSimpleName) + case failure => failure + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { if (ctx.isPrimitiveType(left.dataType) && left.dataType != BooleanType // java boolean doesn't support > or < operator @@ -465,7 +475,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate { } } - protected lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType) + protected lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(left.dataType) } @@ -483,28 +493,13 @@ object Equality { } } +// TODO: although map type is not orderable, technically map type should be able to be used +// in equality comparison @ExpressionDescription( usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` equals `expr2`, or false otherwise.") case class EqualTo(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { - override def inputType: AbstractDataType = AnyDataType - - override def checkInputDataTypes(): TypeCheckResult = { - super.checkInputDataTypes() match { - case TypeCheckResult.TypeCheckSuccess => - // TODO: although map type is not orderable, technically map type should be able to be used - // in equality comparison, remove this type check once we support it. - if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) { - TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualTo, but the actual " + - s"input type is ${left.dataType.catalogString}.") - } else { - TypeCheckResult.TypeCheckSuccess - } - case failure => failure - } - } - override def symbol: String = "=" protected override def nullSafeEval(left: Any, right: Any): Any = ordering.equiv(left, right) @@ -514,6 +509,8 @@ case class EqualTo(left: Expression, right: Expression) } } +// TODO: although map type is not orderable, technically map type should be able to be used +// in equality comparison @ExpressionDescription( usage = """ expr1 _FUNC_ expr2 - Returns same result as the EQUAL(=) operator for non-null operands, @@ -521,23 +518,6 @@ case class EqualTo(left: Expression, right: Expression) """) case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { - override def inputType: AbstractDataType = AnyDataType - - override def checkInputDataTypes(): TypeCheckResult = { - super.checkInputDataTypes() match { - case TypeCheckResult.TypeCheckSuccess => - // TODO: although map type is not orderable, technically map type should be able to be used - // in equality comparison, remove this type check once we support it. - if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) { - TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualNullSafe, but the actual " + - s"input type is ${left.dataType.catalogString}.") - } else { - TypeCheckResult.TypeCheckSuccess - } - case failure => failure - } - } - override def symbol: String = "<=>" override def nullable: Boolean = false @@ -569,8 +549,6 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp case class LessThan(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { - override def inputType: AbstractDataType = TypeCollection.Ordered - override def symbol: String = "<" protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2) @@ -581,8 +559,6 @@ case class LessThan(left: Expression, right: Expression) case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { - override def inputType: AbstractDataType = TypeCollection.Ordered - override def symbol: String = "<=" protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2) @@ -593,8 +569,6 @@ case class LessThanOrEqual(left: Expression, right: Expression) case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { - override def inputType: AbstractDataType = TypeCollection.Ordered - override def symbol: String = ">" protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2) @@ -605,8 +579,6 @@ case class GreaterThan(left: Expression, right: Expression) case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { - override def inputType: AbstractDataType = TypeCollection.Ordered - override def symbol: String = ">=" protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 45225779bffcb..1dcda49a3af6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -65,6 +65,7 @@ object TypeUtils { case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] + case udt: UserDefinedType[_] => getInterpretedOrdering(udt.sqlType) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index 1d54ff5825c2e..3041f44b116ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -78,18 +78,6 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType]) private[sql] object TypeCollection { - /** - * Types that can be ordered/compared. In the long run we should probably make this a trait - * that can be mixed into each data type, and perhaps create an `AbstractDataType`. - */ - // TODO: Should we consolidate this with RowOrdering.isOrderable? - val Ordered = TypeCollection( - BooleanType, - ByteType, ShortType, IntegerType, LongType, - FloatType, DoubleType, DecimalType, - TimestampType, DateType, - StringType, BinaryType) - /** * Types that include numeric types and interval type. They are only used in unary_minus, * unary_positive, add and subtract operations. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 4e0613619add6..884e113537c93 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -505,7 +505,7 @@ class AnalysisErrorSuite extends AnalysisTest { right, joinType = Cross, condition = Some('b === 'd)) - assertAnalysisError(plan2, "Cannot use map type in EqualTo" :: Nil) + assertAnalysisError(plan2, "EqualTo does not support ordering on type MapType" :: Nil) } test("PredicateSubQuery is used outside of a filter") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 30725773a37b1..36714bd631b0e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.types.{LongType, StringType, TypeCollection} +import org.apache.spark.sql.types._ class ExpressionTypeCheckingSuite extends SparkFunSuite { @@ -109,16 +109,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) - assertError(EqualTo('mapField, 'mapField), "Cannot use map type in EqualTo") - assertError(EqualNullSafe('mapField, 'mapField), "Cannot use map type in EqualNullSafe") + assertError(EqualTo('mapField, 'mapField), "EqualTo does not support ordering on type MapType") + assertError(EqualNullSafe('mapField, 'mapField), + "EqualNullSafe does not support ordering on type MapType") assertError(LessThan('mapField, 'mapField), - s"requires ${TypeCollection.Ordered.simpleString} type") + "LessThan does not support ordering on type MapType") assertError(LessThanOrEqual('mapField, 'mapField), - s"requires ${TypeCollection.Ordered.simpleString} type") + "LessThanOrEqual does not support ordering on type MapType") assertError(GreaterThan('mapField, 'mapField), - s"requires ${TypeCollection.Ordered.simpleString} type") + "GreaterThan does not support ordering on type MapType") assertError(GreaterThanOrEqual('mapField, 'mapField), - s"requires ${TypeCollection.Ordered.simpleString} type") + "GreaterThanOrEqual does not support ordering on type MapType") assertError(If('intField, 'stringField, 'stringField), "type of predicate expression in If should be boolean") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index ef510a95ef446..055c31c2b3018 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.{Date, Timestamp} + import scala.collection.immutable.HashSet import org.apache.spark.SparkFunSuite import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ @@ -215,14 +218,35 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } } - private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false).map(Literal(_)) + private case class MyStruct(a: Long, b: String) + private case class MyStruct2(a: MyStruct, b: Array[Int]) + private val udt = new ExamplePointUDT + + private val smallValues = + Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1), + new Timestamp(1), "a", 1f, 1d, 0f, 0d, false, Array(1L, 2L)) + .map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")), + Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))), + Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt)) private val largeValues = - Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN, true).map(Literal(_)) + Seq(2.toByte, 2.toShort, 2, 2L, Decimal(2), Array(2.toByte), new Date(2000, 1, 2), + new Timestamp(2), "b", 2f, 2d, Float.NaN, Double.NaN, true, Array(2L, 1L)) + .map(Literal(_)) ++ Seq(Literal.create(MyStruct(2L, "b")), + Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 2))), + Literal.create(ArrayData.toArrayData(Array(1.0, 3.0)), udt)) private val equalValues1 = - Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_)) + Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1), + new Timestamp(1), "a", 1f, 1d, Float.NaN, Double.NaN, true, Array(1L, 2L)) + .map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")), + Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))), + Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt)) private val equalValues2 = - Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_)) + Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1), + new Timestamp(1), "a", 1f, 1d, Float.NaN, Double.NaN, true, Array(1L, 2L)) + .map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")), + Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))), + Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt)) test("BinaryComparison consistency check") { DataTypeTestUtils.ordered.foreach { dt => @@ -285,11 +309,13 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { // Use -1 (default value for codegen) which can trigger some weird bugs, e.g. SPARK-14757 val normalInt = Literal(-1) val nullInt = NonFoldableLiteral.create(null, IntegerType) + val nullNullType = Literal.create(null, NullType) def nullTest(op: (Expression, Expression) => Expression): Unit = { checkEvaluation(op(normalInt, nullInt), null) checkEvaluation(op(nullInt, normalInt), null) checkEvaluation(op(nullInt, nullInt), null) + checkEvaluation(op(nullNullType, nullNullType), null) } nullTest(LessThan) @@ -301,6 +327,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(EqualNullSafe(normalInt, nullInt), false) checkEvaluation(EqualNullSafe(nullInt, normalInt), false) checkEvaluation(EqualNullSafe(nullInt, nullInt), true) + checkEvaluation(EqualNullSafe(nullNullType, nullNullType), true) } test("EqualTo on complex type") {