Skip to content

Commit

Permalink
refactor(expr): cleanup non-function-call expression Type variants (#…
Browse files Browse the repository at this point in the history
…10094)

Signed-off-by: Bugen Zhao <i@bugenzhao.com>
  • Loading branch information
BugenZhao authored May 31, 2023
1 parent 1e9af38 commit 487c4c6
Show file tree
Hide file tree
Showing 40 changed files with 172 additions and 196 deletions.
14 changes: 9 additions & 5 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,16 @@ option java_package = "com.risingwave.proto";
option optimize_for = SPEED;

message ExprNode {
// TODO: move this into `FunctionCall`.
enum Type {
// `InputRef`, `Constant`, and `UserDefinedFunction` are indicated by the viriant of `rex_node`.
// Their types are therefore deprecated and should be `UNSPECIFIED` instead.
reserved 1, 2, 3000;
reserved "INPUT_REF", "CONSTANT_VALUE", "UDF";

// Used for `InputRef`, `Constant`, and `UserDefinedFunction`.
UNSPECIFIED = 0;
INPUT_REF = 1;
CONSTANT_VALUE = 2;

// arithmetics operators
ADD = 3;
SUBTRACT = 4;
Expand Down Expand Up @@ -192,10 +198,8 @@ message ExprNode {
VNODE = 1101;
// Non-deterministic functions
PROCTIME = 2023;
// User defined functions
UDF = 3000;
}
Type expr_type = 1;
Type function_type = 1;
data.DataType return_type = 3;
oneof rex_node {
uint32 input_ref = 4;
Expand Down
23 changes: 12 additions & 11 deletions src/expr/src/expr/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,22 @@ use crate::{bail, ExprError, Result};
pub fn build_from_prost(prost: &ExprNode) -> Result<BoxedExpression> {
use PbType as E;

match prost.expr_type() {
let func_call = match prost.get_rex_node()? {
RexNode::InputRef(_) => return InputRefExpression::try_from_boxed(prost),
RexNode::Constant(_) => return LiteralExpression::try_from_boxed(prost),
RexNode::Udf(_) => return UdfExpression::try_from_boxed(prost),
RexNode::FuncCall(func_call) => func_call,
};

let func_type = prost.function_type();

match func_type {
// Dedicated types
E::All | E::Some => SomeAllExpression::try_from_boxed(prost),
E::In => InExpression::try_from_boxed(prost),
E::Case => CaseExpression::try_from_boxed(prost),
E::Coalesce => CoalesceExpression::try_from_boxed(prost),
E::ConcatWs => ConcatWsExpression::try_from_boxed(prost),
E::ConstantValue => LiteralExpression::try_from_boxed(prost),
E::InputRef => InputRefExpression::try_from_boxed(prost),
E::Field => FieldExpression::try_from_boxed(prost),
E::Array => NestedConstructExpression::try_from_boxed(prost),
E::Row => NestedConstructExpression::try_from_boxed(prost),
Expand All @@ -62,23 +69,17 @@ pub fn build_from_prost(prost: &ExprNode) -> Result<BoxedExpression> {
ArrayConcatExpression::try_from_boxed(prost)
}
E::Vnode => VnodeExpression::try_from_boxed(prost),
E::Udf => UdfExpression::try_from_boxed(prost),
E::Proctime => ProcTimeExpression::try_from_boxed(prost),

_ => {
let Some(RexNode::FuncCall(call)) = &prost.rex_node else {
return Err(ExprError::UnsupportedFunction(format!("{:?}", prost.rex_node)));
};

let func = prost.expr_type();
let ret_type = DataType::from(prost.get_return_type().unwrap());
let children = call
let children = func_call
.get_children()
.iter()
.map(build_from_prost)
.try_collect()?;

build_func(func, ret_type, children)
build_func(func_type, ret_type, children)
}
}
}
Expand Down
20 changes: 10 additions & 10 deletions src/expr/src/expr/expr_array_concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ impl<'a> TryFrom<&'a ExprNode> for ArrayConcatExpression {
let left_type = left.return_type();
let right_type = right.return_type();
let ret_type = DataType::from(prost.get_return_type()?);
let op = match prost.get_expr_type()? {
let op = match prost.get_function_type()? {
// the types are checked in frontend, so no need for type checking here
Type::ArrayCat => {
if left_type == right_type {
Expand Down Expand Up @@ -389,7 +389,7 @@ mod tests {

fn make_i64_expr_node(value: i64) -> ExprNode {
ExprNode {
expr_type: PbType::ConstantValue as i32,
function_type: PbType::Unspecified as _,
return_type: Some(DataType::Int64.to_protobuf()),
rex_node: Some(RexNode::Constant(PbDatum {
body: value.to_be_bytes().to_vec(),
Expand All @@ -399,7 +399,7 @@ mod tests {

fn make_i64_array_expr_node(values: Vec<i64>) -> ExprNode {
ExprNode {
expr_type: PbType::Array as i32,
function_type: PbType::Array as i32,
return_type: Some(DataType::List(Box::new(DataType::Int64)).to_protobuf()),
rex_node: Some(RexNode::FuncCall(FunctionCall {
children: values.into_iter().map(make_i64_expr_node).collect(),
Expand All @@ -409,7 +409,7 @@ mod tests {

fn make_i64_array_array_expr_node(values: Vec<Vec<i64>>) -> ExprNode {
ExprNode {
expr_type: PbType::Array as i32,
function_type: PbType::Array as i32,
return_type: Some(
DataType::List(Box::new(DataType::List(Box::new(DataType::Int64)))).to_protobuf(),
),
Expand All @@ -425,7 +425,7 @@ mod tests {
let left = make_i64_array_expr_node(vec![42]);
let right = make_i64_array_expr_node(vec![43]);
let expr = ExprNode {
expr_type: PbType::ArrayCat as i32,
function_type: PbType::ArrayCat as i32,
return_type: Some(DataType::List(Box::new(DataType::Int64)).to_protobuf()),
rex_node: Some(RexNode::FuncCall(FunctionCall {
children: vec![left, right],
Expand All @@ -438,7 +438,7 @@ mod tests {
let left = make_i64_array_array_expr_node(vec![vec![42]]);
let right = make_i64_array_array_expr_node(vec![vec![43]]);
let expr = ExprNode {
expr_type: PbType::ArrayCat as i32,
function_type: PbType::ArrayCat as i32,
return_type: Some(DataType::List(Box::new(DataType::Int64)).to_protobuf()),
rex_node: Some(RexNode::FuncCall(FunctionCall {
children: vec![left, right],
Expand All @@ -451,7 +451,7 @@ mod tests {
let left = make_i64_array_expr_node(vec![42]);
let right = make_i64_expr_node(43);
let expr = ExprNode {
expr_type: PbType::ArrayAppend as i32,
function_type: PbType::ArrayAppend as i32,
return_type: Some(DataType::List(Box::new(DataType::Int64)).to_protobuf()),
rex_node: Some(RexNode::FuncCall(FunctionCall {
children: vec![left, right],
Expand All @@ -464,7 +464,7 @@ mod tests {
let left = make_i64_array_array_expr_node(vec![vec![42]]);
let right = make_i64_array_expr_node(vec![43]);
let expr = ExprNode {
expr_type: PbType::ArrayAppend as i32,
function_type: PbType::ArrayAppend as i32,
return_type: Some(DataType::List(Box::new(DataType::Int64)).to_protobuf()),
rex_node: Some(RexNode::FuncCall(FunctionCall {
children: vec![left, right],
Expand All @@ -477,7 +477,7 @@ mod tests {
let left = make_i64_expr_node(43);
let right = make_i64_array_expr_node(vec![42]);
let expr = ExprNode {
expr_type: PbType::ArrayPrepend as i32,
function_type: PbType::ArrayPrepend as i32,
return_type: Some(DataType::List(Box::new(DataType::Int64)).to_protobuf()),
rex_node: Some(RexNode::FuncCall(FunctionCall {
children: vec![left, right],
Expand All @@ -490,7 +490,7 @@ mod tests {
let left = make_i64_array_expr_node(vec![43]);
let right = make_i64_array_array_expr_node(vec![vec![42]]);
let expr = ExprNode {
expr_type: PbType::ArrayPrepend as i32,
function_type: PbType::ArrayPrepend as i32,
return_type: Some(DataType::List(Box::new(DataType::Int64)).to_protobuf()),
rex_node: Some(RexNode::FuncCall(FunctionCall {
children: vec![left, right],
Expand Down
11 changes: 1 addition & 10 deletions src/expr/src/expr/expr_binary_nullable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,13 @@ use super::{BoxedExpression, Expression};
use crate::vector_op::conjunction::{and, or};
use crate::Result;

#[derive(Debug)]
pub struct BinaryShortCircuitExpression {
expr_ia1: BoxedExpression,
expr_ia2: BoxedExpression,
expr_type: Type,
}

impl std::fmt::Debug for BinaryShortCircuitExpression {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("BinaryShortCircuitExpression")
.field("expr_ia1", &self.expr_ia1)
.field("expr_ia2", &self.expr_ia2)
.field("expr_type", &self.expr_type)
.finish()
}
}

#[async_trait::async_trait]
impl Expression for BinaryShortCircuitExpression {
fn return_type(&self) -> DataType {
Expand Down
2 changes: 1 addition & 1 deletion src/expr/src/expr/expr_case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ impl<'a> TryFrom<&'a ExprNode> for CaseExpression {
type Error = ExprError;

fn try_from(prost: &'a ExprNode) -> Result<Self> {
ensure!(prost.get_expr_type().unwrap() == PbType::Case);
ensure!(prost.get_function_type().unwrap() == PbType::Case);

let ret_type = DataType::from(prost.get_return_type().unwrap());
let RexNode::FuncCall(func_call_node) = prost.get_rex_node().unwrap() else {
Expand Down
4 changes: 2 additions & 2 deletions src/expr/src/expr/expr_coalesce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl<'a> TryFrom<&'a ExprNode> for CoalesceExpression {
type Error = ExprError;

fn try_from(prost: &'a ExprNode) -> Result<Self> {
ensure!(prost.get_expr_type().unwrap() == Type::Coalesce);
ensure!(prost.get_function_type().unwrap() == Type::Coalesce);

let ret_type = DataType::from(prost.get_return_type().unwrap());
let RexNode::FuncCall(func_call_node) = prost.get_rex_node().unwrap() else {
Expand Down Expand Up @@ -130,7 +130,7 @@ mod tests {

pub fn make_coalesce_function(children: Vec<ExprNode>, ret: TypeName) -> ExprNode {
ExprNode {
expr_type: Coalesce as i32,
function_type: Coalesce as i32,
return_type: Some(PbDataType {
type_name: ret as i32,
..Default::default()
Expand Down
4 changes: 2 additions & 2 deletions src/expr/src/expr/expr_concat_ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ impl<'a> TryFrom<&'a ExprNode> for ConcatWsExpression {
type Error = ExprError;

fn try_from(prost: &'a ExprNode) -> Result<Self> {
ensure!(prost.get_expr_type().unwrap() == Type::ConcatWs);
ensure!(prost.get_function_type().unwrap() == Type::ConcatWs);

let ret_type = DataType::from(prost.get_return_type().unwrap());
let RexNode::FuncCall(func_call_node) = prost.get_rex_node().unwrap() else {
Expand Down Expand Up @@ -173,7 +173,7 @@ mod tests {

pub fn make_concat_ws_function(children: Vec<ExprNode>, ret: TypeName) -> ExprNode {
ExprNode {
expr_type: ConcatWs as i32,
function_type: ConcatWs as i32,
return_type: Some(PbDataType {
type_name: ret as i32,
..Default::default()
Expand Down
2 changes: 1 addition & 1 deletion src/expr/src/expr/expr_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ impl<'a> TryFrom<&'a ExprNode> for FieldExpression {
type Error = ExprError;

fn try_from(prost: &'a ExprNode) -> Result<Self> {
ensure!(prost.get_expr_type().unwrap() == Type::Field);
ensure!(prost.get_function_type().unwrap() == Type::Field);

let ret_type = DataType::from(prost.get_return_type().unwrap());
let RexNode::FuncCall(func_call_node) = prost.get_rex_node().unwrap() else {
Expand Down
10 changes: 5 additions & 5 deletions src/expr/src/expr/expr_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ impl<'a> TryFrom<&'a ExprNode> for InExpression {
type Error = ExprError;

fn try_from(prost: &'a ExprNode) -> Result<Self> {
ensure!(prost.get_expr_type().unwrap() == Type::In);
ensure!(prost.get_function_type().unwrap() == Type::In);

let ret_type = DataType::from(prost.get_return_type().unwrap());
let RexNode::FuncCall(func_call_node) = prost.get_rex_node().unwrap() else {
Expand Down Expand Up @@ -142,7 +142,7 @@ mod tests {
#[test]
fn test_in_expr() {
let input_ref_expr_node = ExprNode {
expr_type: Type::InputRef as i32,
function_type: Type::Unspecified as i32,
return_type: Some(PbDataType {
type_name: TypeName::Varchar as i32,
..Default::default()
Expand All @@ -151,7 +151,7 @@ mod tests {
};
let constant_values = vec![
ExprNode {
expr_type: Type::ConstantValue as i32,
function_type: Type::Unspecified as i32,
return_type: Some(PbDataType {
type_name: TypeName::Varchar as i32,
..Default::default()
Expand All @@ -161,7 +161,7 @@ mod tests {
})),
},
ExprNode {
expr_type: Type::ConstantValue as i32,
function_type: Type::Unspecified as i32,
return_type: Some(PbDataType {
type_name: TypeName::Varchar as i32,
..Default::default()
Expand All @@ -177,7 +177,7 @@ mod tests {
children: in_children,
};
let p = ExprNode {
expr_type: Type::In as i32,
function_type: Type::In as i32,
return_type: Some(PbDataType {
type_name: TypeName::Boolean as i32,
..Default::default()
Expand Down
19 changes: 7 additions & 12 deletions src/expr/src/expr/expr_input_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@ use std::ops::Index;
use risingwave_common::array::{ArrayRef, DataChunk};
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, Datum};
use risingwave_pb::expr::expr_node::{RexNode, Type};
use risingwave_pb::expr::ExprNode;

use crate::expr::Expression;
use crate::{bail, ensure, ExprError, Result};
use crate::{ExprError, Result};

/// A reference to a column in input relation.
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -65,17 +64,13 @@ impl<'a> TryFrom<&'a ExprNode> for InputRefExpression {
type Error = ExprError;

fn try_from(prost: &'a ExprNode) -> Result<Self> {
ensure!(prost.get_expr_type().unwrap() == Type::InputRef);

let ret_type = DataType::from(prost.get_return_type().unwrap());
if let RexNode::InputRef(input_col_idx) = prost.get_rex_node().unwrap() {
Ok(Self {
return_type: ret_type,
idx: *input_col_idx as _,
})
} else {
bail!("Expect an input ref node")
}
let input_col_idx = prost.get_rex_node().unwrap().as_input_ref().unwrap();

Ok(Self {
return_type: ret_type,
idx: *input_col_idx as _,
})
}
}

Expand Down
10 changes: 5 additions & 5 deletions src/expr/src/expr/expr_jsonb_access.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ mod tests {
let values = FunctionCall {
children: vec![
ExprNode {
expr_type: Type::ConstantValue as i32,
function_type: Type::Unspecified as i32,
return_type: Some(ProstDataType {
type_name: TypeName::Varchar as i32,
..Default::default()
Expand All @@ -317,7 +317,7 @@ mod tests {
})),
},
ExprNode {
expr_type: Type::ConstantValue as i32,
function_type: Type::Unspecified as i32,
return_type: Some(ProstDataType {
type_name: TypeName::Varchar as i32,
..Default::default()
Expand All @@ -331,7 +331,7 @@ mod tests {
let array_index = FunctionCall {
children: vec![
ExprNode {
expr_type: Type::Array as i32,
function_type: Type::Array as i32,
return_type: Some(ProstDataType {
type_name: TypeName::List as i32,
field_type: vec![ProstDataType {
Expand All @@ -343,7 +343,7 @@ mod tests {
rex_node: Some(RexNode::FuncCall(values)),
},
ExprNode {
expr_type: Type::ConstantValue as i32,
function_type: Type::Unspecified as i32,
return_type: Some(ProstDataType {
type_name: TypeName::Int32 as i32,
..Default::default()
Expand All @@ -355,7 +355,7 @@ mod tests {
],
};
let access = ExprNode {
expr_type: Type::ArrayAccess as i32,
function_type: Type::ArrayAccess as i32,
return_type: Some(ProstDataType {
type_name: TypeName::Varchar as i32,
..Default::default()
Expand Down
Loading

0 comments on commit 487c4c6

Please sign in to comment.