Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-21110][SQL] Structs, arrays, and other orderable datatypes should be usable in inequalities #18818

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,7 @@ class CodegenContext {
case array: ArrayType => genComp(array, c1, c2) + " == 0"
case struct: StructType => genComp(struct, c1, c2) + " == 0"
case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
case NullType => "false"
case _ =>
throw new IllegalArgumentException(
"cannot generate equality code for un-comparable type: " + dataType.simpleString)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,16 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P

abstract class BinaryComparison extends BinaryOperator with Predicate {

// Note that we need to give a superset of allowable input types since orderable types are not
// finitely enumerable. The allowable types are checked below by checkInputDataTypes.
override def inputType: AbstractDataType = AnyDataType

override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
TypeUtils.checkForOrderingExpr(left.dataType, this.getClass.getSimpleName)
case failure => failure
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
if (ctx.isPrimitiveType(left.dataType)
&& left.dataType != BooleanType // java boolean doesn't support > or < operator
Expand All @@ -465,7 +475,7 @@ abstract class BinaryComparison extends BinaryOperator with Predicate {
}
}

protected lazy val ordering = TypeUtils.getInterpretedOrdering(left.dataType)
protected lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(left.dataType)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The types accepted by TypeUtils.getInterpretedOrdering are a strict subset of those accepted by RowOrdering.isOrderable. It only accepts AtomicType, ArrayType and StructType.

NullType, UserDefinedType can cause problems.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As ordering is lazily accessed, and any nulls don't lead us to access it in those predicates, NullType should be safe. We should add a related test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

addressed in cc2f3ec

}


Expand All @@ -483,28 +493,13 @@ object Equality {
}
}

// TODO: although map type is not orderable, technically map type should be able to be used
// in equality comparison
@ExpressionDescription(
usage = "expr1 _FUNC_ expr2 - Returns true if `expr1` equals `expr2`, or false otherwise.")
case class EqualTo(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {

override def inputType: AbstractDataType = AnyDataType

override def checkInputDataTypes(): TypeCheckResult = {
super.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
// TODO: although map type is not orderable, technically map type should be able to be used
// in equality comparison, remove this type check once we support it.
if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) {
TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualTo, but the actual " +
s"input type is ${left.dataType.catalogString}.")
} else {
TypeCheckResult.TypeCheckSuccess
}
case failure => failure
}
}

override def symbol: String = "="

protected override def nullSafeEval(left: Any, right: Any): Any = ordering.equiv(left, right)
Expand All @@ -514,30 +509,15 @@ case class EqualTo(left: Expression, right: Expression)
}
}

// TODO: although map type is not orderable, technically map type should be able to be used
// in equality comparison
@ExpressionDescription(
usage = """
expr1 _FUNC_ expr2 - Returns same result as the EQUAL(=) operator for non-null operands,
but returns true if both are null, false if one of the them is null.
""")
case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison {

override def inputType: AbstractDataType = AnyDataType

override def checkInputDataTypes(): TypeCheckResult = {
super.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
// TODO: although map type is not orderable, technically map type should be able to be used
// in equality comparison, remove this type check once we support it.
if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) {
TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualNullSafe, but the actual " +
s"input type is ${left.dataType.catalogString}.")
} else {
TypeCheckResult.TypeCheckSuccess
}
case failure => failure
}
}

override def symbol: String = "<=>"

override def nullable: Boolean = false
Expand Down Expand Up @@ -569,8 +549,6 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
case class LessThan(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {

override def inputType: AbstractDataType = TypeCollection.Ordered

override def symbol: String = "<"

protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2)
Expand All @@ -581,8 +559,6 @@ case class LessThan(left: Expression, right: Expression)
case class LessThanOrEqual(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {

override def inputType: AbstractDataType = TypeCollection.Ordered

override def symbol: String = "<="

protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2)
Expand All @@ -593,8 +569,6 @@ case class LessThanOrEqual(left: Expression, right: Expression)
case class GreaterThan(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {

override def inputType: AbstractDataType = TypeCollection.Ordered

override def symbol: String = ">"

protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2)
Expand All @@ -605,8 +579,6 @@ case class GreaterThan(left: Expression, right: Expression)
case class GreaterThanOrEqual(left: Expression, right: Expression)
extends BinaryComparison with NullIntolerant {

override def inputType: AbstractDataType = TypeCollection.Ordered

override def symbol: String = ">="

protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gteq(input1, input2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ object TypeUtils {
case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
case udt: UserDefinedType[_] => getInterpretedOrdering(udt.sqlType)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,6 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])

private[sql] object TypeCollection {

/**
* Types that can be ordered/compared. In the long run we should probably make this a trait
* that can be mixed into each data type, and perhaps create an `AbstractDataType`.
*/
// TODO: Should we consolidate this with RowOrdering.isOrderable?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious: Do we need to do anything with RowOrdering.isOrderable given the change in this PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, RowOrdering.isOrderable (which is used by TypeUtils.checkForOrderingExpr) returns true on a strict superset of this type collection as it works for complex types that need recursive checks.

val Ordered = TypeCollection(
BooleanType,
ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType, DecimalType,
TimestampType, DateType,
StringType, BinaryType)

/**
* Types that include numeric types and interval type. They are only used in unary_minus,
* unary_positive, add and subtract operations.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ class AnalysisErrorSuite extends AnalysisTest {
right,
joinType = Cross,
condition = Some('b === 'd))
assertAnalysisError(plan2, "Cannot use map type in EqualTo" :: Nil)
assertAnalysisError(plan2, "EqualTo does not support ordering on type MapType" :: Nil)
}

test("PredicateSubQuery is used outside of a filter") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{LongType, StringType, TypeCollection}
import org.apache.spark.sql.types._

class ExpressionTypeCheckingSuite extends SparkFunSuite {

Expand Down Expand Up @@ -109,16 +109,17 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))
assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField))

assertError(EqualTo('mapField, 'mapField), "Cannot use map type in EqualTo")
assertError(EqualNullSafe('mapField, 'mapField), "Cannot use map type in EqualNullSafe")
assertError(EqualTo('mapField, 'mapField), "EqualTo does not support ordering on type MapType")
assertError(EqualNullSafe('mapField, 'mapField),
"EqualNullSafe does not support ordering on type MapType")
assertError(LessThan('mapField, 'mapField),
s"requires ${TypeCollection.Ordered.simpleString} type")
"LessThan does not support ordering on type MapType")
assertError(LessThanOrEqual('mapField, 'mapField),
s"requires ${TypeCollection.Ordered.simpleString} type")
"LessThanOrEqual does not support ordering on type MapType")
assertError(GreaterThan('mapField, 'mapField),
s"requires ${TypeCollection.Ordered.simpleString} type")
"GreaterThan does not support ordering on type MapType")
assertError(GreaterThanOrEqual('mapField, 'mapField),
s"requires ${TypeCollection.Ordered.simpleString} type")
"GreaterThanOrEqual does not support ordering on type MapType")

assertError(If('intField, 'stringField, 'stringField),
"type of predicate expression in If should be boolean")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@

package org.apache.spark.sql.catalyst.expressions

import java.sql.{Date, Timestamp}

import scala.collection.immutable.HashSet

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._


Expand Down Expand Up @@ -215,14 +218,35 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}

private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false).map(Literal(_))
private case class MyStruct(a: Long, b: String)
private case class MyStruct2(a: MyStruct, b: Array[Int])
private val udt = new ExamplePointUDT

private val smallValues =
Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1),
new Timestamp(1), "a", 1f, 1d, 0f, 0d, false, Array(1L, 2L))
.map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))),
Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt))
private val largeValues =
Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN, true).map(Literal(_))
Seq(2.toByte, 2.toShort, 2, 2L, Decimal(2), Array(2.toByte), new Date(2000, 1, 2),
new Timestamp(2), "b", 2f, 2d, Float.NaN, Double.NaN, true, Array(2L, 1L))
.map(Literal(_)) ++ Seq(Literal.create(MyStruct(2L, "b")),
Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 2))),
Literal.create(ArrayData.toArrayData(Array(1.0, 3.0)), udt))

private val equalValues1 =
Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_))
Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1),
new Timestamp(1), "a", 1f, 1d, Float.NaN, Double.NaN, true, Array(1L, 2L))
.map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))),
Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt))
private val equalValues2 =
Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN, true).map(Literal(_))
Seq(1.toByte, 1.toShort, 1, 1L, Decimal(1), Array(1.toByte), new Date(2000, 1, 1),
new Timestamp(1), "a", 1f, 1d, Float.NaN, Double.NaN, true, Array(1L, 2L))
.map(Literal(_)) ++ Seq(Literal.create(MyStruct(1L, "b")),
Literal.create(MyStruct2(MyStruct(1L, "a"), Array(1, 1))),
Literal.create(ArrayData.toArrayData(Array(1.0, 2.0)), udt))

test("BinaryComparison consistency check") {
DataTypeTestUtils.ordered.foreach { dt =>
Expand Down Expand Up @@ -285,11 +309,13 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
// Use -1 (default value for codegen) which can trigger some weird bugs, e.g. SPARK-14757
val normalInt = Literal(-1)
val nullInt = NonFoldableLiteral.create(null, IntegerType)
val nullNullType = Literal.create(null, NullType)

def nullTest(op: (Expression, Expression) => Expression): Unit = {
checkEvaluation(op(normalInt, nullInt), null)
checkEvaluation(op(nullInt, normalInt), null)
checkEvaluation(op(nullInt, nullInt), null)
checkEvaluation(op(nullNullType, nullNullType), null)
}

nullTest(LessThan)
Expand All @@ -301,6 +327,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(EqualNullSafe(normalInt, nullInt), false)
checkEvaluation(EqualNullSafe(nullInt, normalInt), false)
checkEvaluation(EqualNullSafe(nullInt, nullInt), true)
checkEvaluation(EqualNullSafe(nullNullType, nullNullType), true)
}

test("EqualTo on complex type") {
Expand Down