diff --git a/velox/connectors/Connector.h b/velox/connectors/Connector.h index 60c2b1ea9bec..971e528311bf 100644 --- a/velox/connectors/Connector.h +++ b/velox/connectors/Connector.h @@ -47,6 +47,10 @@ namespace facebook::velox::core { class ITypedExpr; } +namespace facebook::velox::core { +struct IndexJoinCondition; +} + namespace facebook::velox::connector { class DataSource; @@ -584,8 +588,8 @@ class Connector { /// Here, /// - 'inputType' is ROW{t.sid, t.event_list} /// - 'numJoinKeys' is 1 since only t.sid is used in join equi-clauses. - /// - 'joinConditions' is list of one expression: contains(t.event_list, - /// u.event_type) + /// - 'joinConditions' specifies the join condition: contains(t.event_list, + /// u.event_type) /// - 'outputType' is ROW{u.event_value} /// - 'tableHandle' specifies the metadata of the index table. /// - 'columnHandles' is a map from 'u.event_type' (in 'joinConditions') and @@ -596,7 +600,7 @@ class Connector { virtual std::shared_ptr createIndexSource( const RowTypePtr& inputType, size_t numJoinKeys, - const std::vector>& + const std::vector>& joinConditions, const RowTypePtr& outputType, const std::shared_ptr& tableHandle, diff --git a/velox/core/PlanNode.cpp b/velox/core/PlanNode.cpp index 0d92fe370355..91dc12eb5b7b 100644 --- a/velox/core/PlanNode.cpp +++ b/velox/core/PlanNode.cpp @@ -40,15 +40,35 @@ std::vector deserializeSources( return {}; } -std::vector deserializeJoinConditions( +namespace { +IndexJoinConditionPtr createIndexJoinCondition( + const folly::dynamic& obj, + void* context) { + VELOX_USER_CHECK_EQ(obj.count("type"), 1); + if (obj["type"] == "in") { + return InIndexJoinCondition::create(obj, context); + } + if (obj["type"] == "between") { + return BetweenIndexJoinCondition::create(obj, context); + } + VELOX_USER_FAIL( + "Unknown index join condition type {}", obj["type"].asString()); +} +} // namespace + +std::vector deserializeJoinConditions( const folly::dynamic& obj, void* context) { if 
(obj.count("joinConditions") == 0) { return {}; } - return ISerializable::deserialize>( - obj["joinConditions"], context); + std::vector joinConditions; + joinConditions.reserve(obj.count("joinConditions")); + for (const auto& joinCondition : obj["joinConditions"]) { + joinConditions.push_back(createIndexJoinCondition(joinCondition, context)); + } + return joinConditions; } PlanNodePtr deserializeSingleSource(const folly::dynamic& obj, void* context) { @@ -1447,8 +1467,7 @@ PlanNodePtr IndexLookupJoinNode::create( folly::dynamic IndexLookupJoinNode::serialize() const { auto obj = serializeBase(); if (!joinConditions_.empty()) { - folly::dynamic serializedJoins = folly::dynamic::array; - serializedJoins.reserve(joinConditions_.size()); + folly::dynamic serializedJoins = folly::dynamic::array(); for (const auto& joinCondition : joinConditions_) { serializedJoins.push_back(joinCondition->serialize()); } @@ -2881,4 +2900,53 @@ PlanNodePtr FilterNode::create(const folly::dynamic& obj, void* context) { deserializePlanNodeId(obj), filter, std::move(source)); } +folly::dynamic IndexJoinCondition::serialize() const { + folly::dynamic obj = folly::dynamic::object; + obj["key"] = key->serialize(); + return obj; +} + +folly::dynamic InIndexJoinCondition::serialize() const { + folly::dynamic obj = IndexJoinCondition::serialize(); + obj["type"] = "in"; + obj["in"] = list->serialize(); + return obj; +} + +std::string InIndexJoinCondition::toString() const { + return fmt::format("{} IN {}", key->toString(), list->toString()); +} + +IndexJoinConditionPtr InIndexJoinCondition::create( + const folly::dynamic& obj, + void* context) { + return std::make_shared( + ISerializable::deserialize(obj["key"], context), + ISerializable::deserialize(obj["in"], context)); +} + +folly::dynamic BetweenIndexJoinCondition::serialize() const { + folly::dynamic obj = IndexJoinCondition::serialize(); + obj["type"] = "between"; + obj["lower"] = lower->serialize(); + obj["upper"] = upper->serialize(); + 
return obj; +} + +std::string BetweenIndexJoinCondition::toString() const { + return fmt::format( + "{} BETWEEN {} AND {}", + key->toString(), + lower->toString(), + upper->toString()); +} + +IndexJoinConditionPtr BetweenIndexJoinCondition::create( + const folly::dynamic& obj, + void* context) { + return std::make_shared( + ISerializable::deserialize(obj["key"], context), + ISerializable::deserialize(obj["lower"], context), + ISerializable::deserialize(obj["upper"], context)); +} } // namespace facebook::velox::core diff --git a/velox/core/PlanNode.h b/velox/core/PlanNode.h index 1f2505558f33..441ac77654df 100644 --- a/velox/core/PlanNode.h +++ b/velox/core/PlanNode.h @@ -1768,6 +1768,66 @@ class MergeJoinNode : public AbstractJoinNode { static PlanNodePtr create(const folly::dynamic& obj, void* context); }; +struct IndexJoinCondition; +using IndexJoinConditionPtr = std::shared_ptr; +struct IndexJoinCondition : public ISerializable { + /// References to an index table column. + FieldAccessTypedExprPtr key; + + IndexJoinCondition(FieldAccessTypedExprPtr _key) : key(std::move(_key)) {} + + folly::dynamic serialize() const override; + + virtual std::string toString() const = 0; +}; + +/// Represents IN-LIST index join condition: contains('list', 'key'). 'list' has +/// type of ARRAY(typeof('key')). +struct InIndexJoinCondition : public IndexJoinCondition { + /// References to the probe input column which is ARRAY with element type of + /// the corresponding 'key' column from index table. 
+ FieldAccessTypedExprPtr list; + + InIndexJoinCondition( + FieldAccessTypedExprPtr _key, + FieldAccessTypedExprPtr _list) + : IndexJoinCondition(std::move(_key)), list(std::move(_list)) {} + + folly::dynamic serialize() const override; + + std::string toString() const override; + + static IndexJoinConditionPtr create(const folly::dynamic& obj, void* context); +}; +using InIndexJoinConditionPtr = std::shared_ptr; + +/// Represents BETWEEN index join condition: 'key' between 'lower' and 'upper'. +/// 'lower' and 'upper' have the same type of 'key'. +struct BetweenIndexJoinCondition : public IndexJoinCondition { + /// The between bound either reference to a probe input column or a constant + /// value. + /// + /// NOTE: the bound is inclusive, and at least one of the bound references to + /// a probe input column. + TypedExprPtr lower; + TypedExprPtr upper; + + BetweenIndexJoinCondition( + FieldAccessTypedExprPtr _key, + TypedExprPtr _lower, + TypedExprPtr _upper) + : IndexJoinCondition(std::move(_key)), + lower(std::move(_lower)), + upper(std::move(_upper)) {} + + folly::dynamic serialize() const override; + + std::string toString() const override; + + static IndexJoinConditionPtr create(const folly::dynamic& obj, void* context); +}; +using BetweenIndexJoinConditionPtr = std::shared_ptr; + /// Represents index lookup join. Translates to an exec::IndexLookupJoin /// operator. 
Assumes the right input is a table scan source node that provides /// indexed table lookup for the left input with the specified join keys and @@ -1796,8 +1856,8 @@ class MergeJoinNode : public AbstractJoinNode { /// maybe some more) /// - 'leftKeys' is a list of one key 't.sid' /// - 'rightKeys' is a list of one key 'u.sid' -/// - 'joinConditions' is a list of one expression: contains(t.event_list, -/// u.event_type) +/// - 'joinConditions' specifies one condition: contains(t.event_list, +/// u.event_type) /// - 'outputType' contains 3 columns : t.sid, t.day_ts, u.event_type /// class IndexLookupJoinNode : public AbstractJoinNode { @@ -1809,7 +1869,7 @@ class IndexLookupJoinNode : public AbstractJoinNode { JoinType joinType, const std::vector& leftKeys, const std::vector& rightKeys, - const std::vector& joinConditions, + const std::vector& joinConditions, PlanNodePtr left, TableScanNodePtr right, RowTypePtr outputType) @@ -1849,7 +1909,7 @@ class IndexLookupJoinNode : public AbstractJoinNode { return lookupSourceNode_; } - const std::vector& joinConditions() const { + const std::vector& joinConditions() const { return joinConditions_; } @@ -1869,7 +1929,7 @@ class IndexLookupJoinNode : public AbstractJoinNode { const TableScanNodePtr lookupSourceNode_; - const std::vector joinConditions_; + const std::vector joinConditions_; }; /// Represents inner/outer nested loop joins. 
Translates to an diff --git a/velox/exec/IndexLookupJoin.cpp b/velox/exec/IndexLookupJoin.cpp index 0df9b44fb861..b806b41a8fdf 100644 --- a/velox/exec/IndexLookupJoin.cpp +++ b/velox/exec/IndexLookupJoin.cpp @@ -31,7 +31,102 @@ void duplicateJoinKeyCheck( } VELOX_USER_CHECK_EQ(lookupKeyNames.size(), keys.size()); } + +std::string getColumnName(const core::TypedExprPtr& typeExpr) { + const auto field = core::TypedExprs::asFieldAccess(typeExpr); + VELOX_USER_CHECK(field->isInputColumn()); + return field->name(); +} + +// Adds a probe input column to lookup input channels and type if the probe +// column is used in a join condition. The lookup input is projected from the +// probe input and feeds into index source for lookup. +void addLookupInputColumn( + const std::string& columnName, + const TypePtr& columnType, + column_index_t columnChannel, + std::vector& lookupInputNames, + std::vector& lookupInputTypes, + std::vector& lookupInputChannels, + folly::F14FastSet& lookupInputNameSet) { + if (lookupInputNameSet.count(columnName) != 0) { + return; + } + lookupInputNames.emplace_back(columnName); + lookupInputTypes.emplace_back(columnType); + lookupInputChannels.emplace_back(columnChannel); + lookupInputNameSet.insert(columnName); +} + +// Validates one of the between bounds, and updates the lookup input channels and type +// to include the corresponding probe input column if the bound is not constant. 
+bool addBetweenConditionBound( + const core::TypedExprPtr& typeExpr, + const RowTypePtr& inputType, + const TypePtr& indexKeyType, + std::vector& lookupInputNames, + std::vector& lookupInputTypes, + std::vector& lookupInputChannels, + folly::F14FastSet& lookupInputNameSet) { + const bool isConstant = core::TypedExprs::isConstant(typeExpr); + if (!isConstant) { + const auto conditionColumnName = getColumnName(typeExpr); + const auto conditionColumnChannel = + inputType->getChildIdx(conditionColumnName); + const auto conditionColumnType = inputType->childAt(conditionColumnChannel); + VELOX_USER_CHECK(conditionColumnType->equivalent(*indexKeyType)); + addLookupInputColumn( + conditionColumnName, + conditionColumnType, + conditionColumnChannel, + lookupInputNames, + lookupInputTypes, + lookupInputChannels, + lookupInputNameSet); + } else { + VELOX_USER_CHECK(core::TypedExprs::asConstant(typeExpr)->type()->equivalent( + *indexKeyType)); + } + return isConstant; +} + +// Process a between join condition by validating the lower and upper bound +// types, and updating the lookup input channels and type to include the probe +// input columns which contain the between condition bounds. 
+void addBetweenCondition( + const core::BetweenIndexJoinConditionPtr& betweenCondition, + const RowTypePtr& inputType, + const TypePtr& indexKeyType, + std::vector& lookupInputNames, + std::vector& lookupInputTypes, + std::vector& lookupInputChannels, + folly::F14FastSet& lookupInputNameSet) { + size_t numConstants{0}; + numConstants += !!addBetweenConditionBound( + betweenCondition->lower, + inputType, + indexKeyType, + lookupInputNames, + lookupInputTypes, + lookupInputChannels, + lookupInputNameSet); + numConstants += !!addBetweenConditionBound( + betweenCondition->upper, + inputType, + indexKeyType, + lookupInputNames, + lookupInputTypes, + lookupInputChannels, + lookupInputNameSet); + + VELOX_USER_CHECK_LT( + numConstants, + 2, + "At least one of the between condition bounds needs to be not constant: {}", + betweenCondition->toString()); +} } // namespace + IndexLookupJoin::IndexLookupJoin( int32_t operatorId, DriverCtx* driverCtx, @@ -50,6 +145,7 @@ IndexLookupJoin::IndexLookupJoin( probeType_{joinNode->sources()[0]->outputType()}, lookupType_{joinNode->lookupSource()->outputType()}, lookupTableHandle_{joinNode->lookupSource()->tableHandle()}, + lookupConditions_{joinNode->joinConditions()}, lookupColumnHandles_(joinNode->lookupSource()->assignments()), connectorQueryCtx_{operatorCtx_->createConnectorQueryCtx( lookupTableHandle_->connectorId(), @@ -61,7 +157,6 @@ IndexLookupJoin::IndexLookupJoin( operatorType(), lookupTableHandle_->connectorId()), spillConfig_.has_value() ? 
&(spillConfig_.value()) : nullptr)}, - expressionEvaluator_(connectorQueryCtx_->expressionEvaluator()), connector_(connector::getConnector(lookupTableHandle_->connectorId())), joinNode_{joinNode} { duplicateJoinKeyCheck(joinNode_->leftKeys()); @@ -107,76 +202,74 @@ void IndexLookupJoin::initLookupInput() { VELOX_CHECK_EQ(lookupInputNames.size(), lookupInputChannels_.size()); lookupInputType_ = ROW(std::move(lookupInputNames), std::move(lookupInputTypes)); + VELOX_CHECK_EQ(lookupInputType_->size(), lookupInputChannels_.size()); }; - // List probe key columns used in join-equi caluse first. - folly::F14FastSet probeKeyColumnNames; + folly::F14FastSet lookupInputColumnSet; + folly::F14FastSet lookupIndexColumnSet; + // List probe columns used in join equi-clauses first. for (auto keyIdx = 0; keyIdx < numKeys_; ++keyIdx) { - lookupInputNames.emplace_back(joinNode_->leftKeys()[keyIdx]->name()); - const auto probeKeyChannel = - probeType_->getChildIdx(lookupInputNames.back()); - lookupInputChannels_.emplace_back(probeKeyChannel); - lookupInputTypes.emplace_back(probeType_->childAt(probeKeyChannel)); - VELOX_CHECK_EQ(probeKeyColumnNames.count(lookupInputNames.back()), 0); - probeKeyColumnNames.insert(lookupInputNames.back()); + const auto probeKeyName = joinNode_->leftKeys()[keyIdx]->name(); + const auto indexKeyName = joinNode_->rightKeys()[keyIdx]->name(); + VELOX_USER_CHECK_EQ(lookupIndexColumnSet.count(indexKeyName), 0); + lookupIndexColumnSet.insert(indexKeyName); + const auto probeKeyChannel = probeType_->getChildIdx(probeKeyName); + const auto probeKeyType = probeType_->childAt(probeKeyChannel); + VELOX_USER_CHECK( + lookupType_->findChild(indexKeyName)->equivalent(*probeKeyType)); + addLookupInputColumn( + indexKeyName, + probeKeyType, + probeKeyChannel, + lookupInputNames, + lookupInputTypes, + lookupInputChannels_, + lookupInputColumnSet); } if (lookupConditions_.empty()) { return; } - folly::F14FastSet probeConditionColumnNames; - folly::F14FastSet 
lookupConditionColumnNames; for (const auto& lookupCondition : lookupConditions_) { - const auto lookupConditionExprSet = - expressionEvaluator_->compile(lookupCondition); - const auto& lookupConditionExpr = lookupConditionExprSet->expr(0); - - int numProbeColumns{0}; - int numLookupColumns{0}; - for (auto& input : lookupConditionExpr->distinctFields()) { - const auto& columnName = input->field(); - auto probeIndexOpt = probeType_->getChildIdxIfExists(columnName); - if (probeIndexOpt.has_value()) { - ++numProbeColumns; - // There is no overlap between probe key columns and probe condition - // columns. - VELOX_CHECK_EQ(probeKeyColumnNames.count(columnName), 0); - // We allow the probe column used in more than one lookup conditions. - if (probeConditionColumnNames.count(columnName) == 0) { - probeConditionColumnNames.insert(columnName); - lookupInputChannels_.push_back(probeIndexOpt.value()); - lookupInputNames.push_back(columnName); - lookupInputTypes.push_back(input->type()); - } - continue; - } - - ++numLookupColumns; - auto lookupIndexOpt = lookupType_->getChildIdxIfExists(columnName); - VELOX_CHECK( - lookupIndexOpt.has_value(), - "Lookup condition column {} is not found", - columnName); - // A lookup column can only be used in one lookup condition. 
- VELOX_CHECK_EQ( - lookupConditionColumnNames.count(columnName), - 0, - "Lookup condition column {} from lookup table used in more than one lookup conditions", - input->field()); - lookupConditionColumnNames.insert(input->field()); + const auto indexKeyName = getColumnName(lookupCondition->key); + VELOX_USER_CHECK_EQ(lookupIndexColumnSet.count(indexKeyName), 0); + lookupIndexColumnSet.insert(indexKeyName); + const auto indexKeyType = lookupType_->findChild(indexKeyName); + + if (const auto inCondition = + std::dynamic_pointer_cast( + lookupCondition)) { + const auto conditionInputName = getColumnName(inCondition->list); + const auto conditionInputChannel = + probeType_->getChildIdx(conditionInputName); + const auto conditionInputType = + probeType_->childAt(conditionInputChannel); + const auto expectedConditionInputType = ARRAY(indexKeyType); + VELOX_USER_CHECK( + conditionInputType->equivalent(*expectedConditionInputType)); + addLookupInputColumn( + conditionInputName, + conditionInputType, + conditionInputChannel, + lookupInputNames, + lookupInputTypes, + lookupInputChannels_, + lookupInputColumnSet); } - VELOX_CHECK_EQ( - numLookupColumns, - 1, - "Unexpected number of lookup columns in lookup condition {}", - lookupConditionExpr->toString()); - VELOX_CHECK_GT( - numProbeColumns, - 0, - "No probe columns found in lookup condition {}", - lookupConditionExpr->toString()); + if (const auto betweenCondition = + std::dynamic_pointer_cast( + lookupCondition)) { + addBetweenCondition( + betweenCondition, + probeType_, + indexKeyType, + lookupInputNames, + lookupInputTypes, + lookupInputChannels_, + lookupInputColumnSet); + } } } diff --git a/velox/exec/IndexLookupJoin.h b/velox/exec/IndexLookupJoin.h index 5931f96b1711..fabc82bd61c3 100644 --- a/velox/exec/IndexLookupJoin.h +++ b/velox/exec/IndexLookupJoin.h @@ -92,11 +92,10 @@ class IndexLookupJoin : public Operator { const RowTypePtr probeType_; const RowTypePtr lookupType_; const std::shared_ptr lookupTableHandle_; 
- const std::vector lookupConditions_; + const std::vector lookupConditions_; std::unordered_map> lookupColumnHandles_; const std::shared_ptr connectorQueryCtx_; - core::ExpressionEvaluator* const expressionEvaluator_; const std::shared_ptr connector_; // The lookup join plan node used to initialize this operator and reset after diff --git a/velox/exec/tests/IndexLookupJoinTest.cpp b/velox/exec/tests/IndexLookupJoinTest.cpp index 72d30a708fad..de0caf7ddea3 100644 --- a/velox/exec/tests/IndexLookupJoinTest.cpp +++ b/velox/exec/tests/IndexLookupJoinTest.cpp @@ -323,16 +323,24 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { kTestIndexConnectorName, nullptr, true); auto left = makeRowVector( - {"t0", "t1", "t2"}, - {makeFlatVector({1, 2, 3}), + {"t0", "t1", "t2", "t3", "t4"}, + {makeFlatVector({1, 2, 3}), makeFlatVector({10, 20, 30}), - makeFlatVector({10, 30, 20})}); + makeFlatVector({10, 30, 20}), + makeArrayVector( + 3, + [](auto row) { return row; }, + [](auto /*unused*/, auto index) { return index; }), + makeArrayVector( + 3, + [](auto row) { return row; }, + [](auto /*unused*/, auto index) { return index; })}); auto right = makeRowVector( {"u0", "u1", "u2"}, - {makeFlatVector({1, 2, 3}), + {makeFlatVector({1, 2, 3}), makeFlatVector({10, 20, 30}), - makeFlatVector({10, 30, 20})}); + makeFlatVector({10, 30, 20})}); auto planNodeIdGenerator = std::make_shared(); @@ -372,21 +380,65 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { testSerde(plan); } - // with join conditions. + // with in join conditions. 
for (const auto joinType : {core::JoinType::kLeft, core::JoinType::kInner}) { - auto plan = PlanBuilder(planNodeIdGenerator) + auto plan = PlanBuilder(planNodeIdGenerator, pool_.get()) + .values({left}) + .indexLookupJoin( + {"t0"}, + {"u0"}, + indexTableScan, + {"contains(t3, u0)", "contains(t4, u1)"}, + {"t0", "u1", "t2", "t1"}, + joinType) + .planNode(); + auto indexLookupJoinNode = + std::dynamic_pointer_cast(plan); + ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 2); + ASSERT_EQ( + indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), + kTestIndexConnectorName); + testSerde(plan); + } + + // with between join conditions. + for (const auto joinType : {core::JoinType::kLeft, core::JoinType::kInner}) { + auto plan = PlanBuilder(planNodeIdGenerator, pool_.get()) .values({left}) .indexLookupJoin( {"t0"}, {"u0"}, indexTableScan, - {"u1 > t2"}, + {"u0 between t0 AND t1", + "u1 between t1 AND 10", + "u1 between 10 AND t1"}, {"t0", "u1", "t2", "t1"}, joinType) .planNode(); auto indexLookupJoinNode = std::dynamic_pointer_cast(plan); - ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 1); + ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 3); + ASSERT_EQ( + indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), + kTestIndexConnectorName); + testSerde(plan); + } + + // with mix join conditions. 
+ for (const auto joinType : {core::JoinType::kLeft, core::JoinType::kInner}) { + auto plan = PlanBuilder(planNodeIdGenerator, pool_.get()) + .values({left}) + .indexLookupJoin( + {"t0"}, + {"u0"}, + indexTableScan, + {"contains(t3, u0)", "u1 between 10 AND t1"}, + {"t0", "u1", "t2", "t1"}, + joinType) + .planNode(); + auto indexLookupJoinNode = + std::dynamic_pointer_cast(plan); + ASSERT_EQ(indexLookupJoinNode->joinConditions().size(), 2); ASSERT_EQ( indexLookupJoinNode->lookupSource()->tableHandle()->connectorId(), kTestIndexConnectorName); @@ -429,7 +481,7 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { {"t0", "t1"}, {"u0"}, indexTableScan, - {"u1 > t2"}, + {"contains(t4, u0)"}, {"t0", "u1", "t2", "t1"}) .planNode(), "JoinNode requires same number of join keys on left and right sides"); @@ -441,7 +493,11 @@ TEST_P(IndexLookupJoinTest, planNodeAndSerde) { PlanBuilder(planNodeIdGenerator) .values({left}) .indexLookupJoin( - {}, {}, indexTableScan, {"u1 > t2"}, {"t0", "u1", "t2", "t1"}) + {}, + {}, + indexTableScan, + {"contains(t4, u0)"}, + {"t0", "u1", "t2", "t1"}) .planNode(), "JoinNode requires at least one join key"); } diff --git a/velox/exec/tests/utils/PlanBuilder.cpp b/velox/exec/tests/utils/PlanBuilder.cpp index 189e2ab45fc4..4400c96cb0fd 100644 --- a/velox/exec/tests/utils/PlanBuilder.cpp +++ b/velox/exec/tests/utils/PlanBuilder.cpp @@ -65,6 +65,44 @@ std::shared_ptr buildHiveBucketProperty( bucketTypes, sortBy); } + +core::IndexJoinConditionPtr parseJoinCondition( + const std::string& joinCondition, + const RowTypePtr& rowType, + const parse::ParseOptions& options, + memory::MemoryPool* pool) { + const auto joinConditionExpr = + parseExpr(joinCondition, rowType, options, pool); + const auto typedCallExpr = + std::dynamic_pointer_cast(joinConditionExpr); + VELOX_CHECK_NOT_NULL(typedCallExpr); + if (typedCallExpr->name() == "contains") { + VELOX_CHECK_EQ(typedCallExpr->inputs().size(), 2); + auto keyColumnExpr = + std::dynamic_pointer_cast( + 
typedCallExpr->inputs()[1]); + VELOX_CHECK_NOT_NULL(keyColumnExpr); + auto conditionColumnExpr = + std::dynamic_pointer_cast( + typedCallExpr->inputs()[0]); + VELOX_CHECK_NOT_NULL(conditionColumnExpr); + return std::make_shared( + std::move(keyColumnExpr), std::move(conditionColumnExpr)); + } + + if (typedCallExpr->name() == "between") { + VELOX_CHECK_EQ(typedCallExpr->inputs().size(), 3); + auto keyColumnExpr = + std::dynamic_pointer_cast( + typedCallExpr->inputs()[0]); + VELOX_CHECK_NOT_NULL(keyColumnExpr); + const auto& lowerExpr = typedCallExpr->inputs()[1]; + const auto& upperExpr = typedCallExpr->inputs()[2]; + return std::make_shared( + std::move(keyColumnExpr), lowerExpr, upperExpr); + } + VELOX_USER_FAIL("Invalid index join condition: {}", joinCondition); +} } // namespace PlanBuilder& PlanBuilder::tableScan( @@ -1590,11 +1628,11 @@ PlanBuilder& PlanBuilder::indexLookupJoin( auto leftKeyFields = fields(planNode_->outputType(), leftKeys); auto rightKeyFields = fields(right->outputType(), rightKeys); - std::vector joinConditionExprs{}; - joinConditionExprs.reserve(joinConditions.size()); + std::vector joinConditionPtrs{}; + joinConditionPtrs.reserve(joinConditions.size()); for (const auto& joinCondition : joinConditions) { - joinConditionExprs.push_back( - parseExpr(joinCondition, inputType, options_, pool_)); + joinConditionPtrs.push_back( + parseJoinCondition(joinCondition, inputType, options_, pool_)); } planNode_ = std::make_shared( @@ -1602,7 +1640,7 @@ PlanBuilder& PlanBuilder::indexLookupJoin( joinType, std::move(leftKeyFields), std::move(rightKeyFields), - std::move(joinConditionExprs), + std::move(joinConditionPtrs), std::move(planNode_), right, std::move(outputType)); diff --git a/velox/exec/tests/utils/PlanBuilder.h b/velox/exec/tests/utils/PlanBuilder.h index 480d0faf068e..d329e2e02818 100644 --- a/velox/exec/tests/utils/PlanBuilder.h +++ b/velox/exec/tests/utils/PlanBuilder.h @@ -1122,7 +1122,15 @@ class PlanBuilder { /// @param right The right 
input source with index lookup support. /// @param joinCondition SQL expressions as the join conditions. Each join /// condition must use columns from both sides. For the right side, it can - /// only use one index column. + /// only use one index column. Currently we support "in" and "between" join + /// conditions: + /// An "in" condition is written as the SQL expression "contains(a, b)" where "b" + /// is the index column from the right side and "a" is the condition column from + /// the left side. "b" has type T and "a" has type ARRAY(T). + /// A "between" condition is written as the SQL expression "a between b and c" + /// where "a" is the index column from the right side and "b", "c" are each either + /// a condition column from the left side or a constant, but at least one of them + /// must not be constant. They all have the same type. /// @param joinType Type of the join supported: inner, left. /// /// See hashJoin method for the description of the other parameters. diff --git a/velox/exec/tests/utils/TestIndexStorageConnector.cpp b/velox/exec/tests/utils/TestIndexStorageConnector.cpp index cc92f706ec0a..f496793d67dd 100644 --- a/velox/exec/tests/utils/TestIndexStorageConnector.cpp +++ b/velox/exec/tests/utils/TestIndexStorageConnector.cpp @@ -230,7 +230,7 @@ TestIndexConnector::TestIndexConnector( std::shared_ptr TestIndexConnector::createIndexSource( const RowTypePtr& inputType, size_t numJoinKeys, - const std::vector>& joinConditions, + const std::vector& joinConditions, const RowTypePtr& outputType, const std::shared_ptr& tableHandle, const std::unordered_map< diff --git a/velox/exec/tests/utils/TestIndexStorageConnector.h b/velox/exec/tests/utils/TestIndexStorageConnector.h index 668eafc2bba0..169e7d32bd38 100644 --- a/velox/exec/tests/utils/TestIndexStorageConnector.h +++ b/velox/exec/tests/utils/TestIndexStorageConnector.h @@ -245,8 +245,7 @@ class TestIndexConnector : public connector::Connector { std::shared_ptr createIndexSource( const RowTypePtr& inputType, size_t numJoinKeys, - 
const std::vector>& - joinConditions, + const std::vector& joinConditions, const RowTypePtr& outputType, const std::shared_ptr& tableHandle, const std::unordered_map<