Skip to content

Commit

Permalink
feat(experimental): Compile match expressions (#7312)
Browse files Browse the repository at this point in the history
Co-authored-by: Maxim Vezenov <mvezenov@gmail.com>
  • Loading branch information
jfecher and vezenovm authored Feb 18, 2025
1 parent f37eedc commit 4c3dee1
Show file tree
Hide file tree
Showing 15 changed files with 1,435 additions and 41 deletions.
155 changes: 154 additions & 1 deletion compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod program;
mod value;

use acvm::AcirField;
use noirc_frontend::hir_def::expr::Constructor;
use noirc_frontend::token::FmtStrFragment;
pub(crate) use program::Ssa;

Expand All @@ -11,7 +12,7 @@ use iter_extended::{try_vecmap, vecmap};
use noirc_errors::Location;
use noirc_frontend::ast::{UnaryOp, Visibility};
use noirc_frontend::hir_def::types::Type as HirType;
use noirc_frontend::monomorphization::ast::{self, Expression, Program, While};
use noirc_frontend::monomorphization::ast::{self, Expression, MatchCase, Program, While};

use crate::{
errors::RuntimeError,
Expand Down Expand Up @@ -155,6 +156,7 @@ impl<'a> FunctionContext<'a> {
Expression::Loop(block) => self.codegen_loop(block),
Expression::While(while_) => self.codegen_while(while_),
Expression::If(if_expr) => self.codegen_if(if_expr),
Expression::Match(match_expr) => self.codegen_match(match_expr),
Expression::Tuple(tuple) => self.codegen_tuple(tuple),
Expression::ExtractTupleField(tuple, index) => {
self.codegen_extract_tuple_field(tuple, *index)
Expand Down Expand Up @@ -752,6 +754,157 @@ impl<'a> FunctionContext<'a> {
})
}

fn codegen_match(&mut self, match_expr: &ast::Match) -> Result<Values, RuntimeError> {
let variable = self.lookup(match_expr.variable_to_match);

// Any matches with only a single case we don't need to check the tag at all.
// Note that this includes all matches on struct / tuple values.
if match_expr.cases.len() == 1 && match_expr.default_case.is_none() {
return self.no_match(variable, &match_expr.cases[0]);
}

// From here on we can assume `variable` is an enum, int, or bool value (not a struct/tuple)
let tag = self.enum_tag(&variable);
let tag_type = self.builder.type_of_value(tag).unwrap_numeric();

let end_block = self.builder.insert_block();

// Optimization: if there is no default case we can jump directly to the last case
// when finished with the previous case instead of using a jmpif with an unreachable
// else block.
let last_case = if match_expr.default_case.is_some() {
match_expr.cases.len()
} else {
match_expr.cases.len() - 1
};

for i in 0..last_case {
let case = &match_expr.cases[i];
let variant_tag = self.variant_index_value(&case.constructor, tag_type)?;
let eq = self.builder.insert_binary(tag, BinaryOp::Eq, variant_tag);

let case_block = self.builder.insert_block();
let else_block = self.builder.insert_block();
self.builder.terminate_with_jmpif(eq, case_block, else_block);

self.builder.switch_to_block(case_block);
self.bind_case_arguments(variable.clone(), case);
let results = self.codegen_expression(&case.branch)?.into_value_list(self);
self.builder.terminate_with_jmp(end_block, results);

self.builder.switch_to_block(else_block);
}

if let Some(branch) = &match_expr.default_case {
let results = self.codegen_expression(branch)?.into_value_list(self);
self.builder.terminate_with_jmp(end_block, results);
} else {
// If there is no default case, assume we saved the last case from the
// last_case optimization above
let case = match_expr.cases.last().unwrap();
self.bind_case_arguments(variable, case);
let results = self.codegen_expression(&case.branch)?.into_value_list(self);
self.builder.terminate_with_jmp(end_block, results);
}

self.builder.switch_to_block(end_block);
let result = Self::map_type(&match_expr.typ, |typ| {
self.builder.add_block_parameter(end_block, typ).into()
});
Ok(result)
}

fn variant_index_value(
&mut self,
constructor: &Constructor,
typ: NumericType,
) -> Result<ValueId, RuntimeError> {
match constructor {
Constructor::Int(value) => {
self.checked_numeric_constant(value.field, value.is_negative, typ)
}
other => Ok(self.builder.numeric_constant(other.variant_index(), typ)),
}
}

fn no_match(&mut self, variable: Values, case: &MatchCase) -> Result<Values, RuntimeError> {
if !case.arguments.is_empty() {
self.bind_case_arguments(variable, case);
}
self.codegen_expression(&case.branch)
}

/// Extracts the tag value from an enum. Assumes enums are represented as a tuple
/// where the tag is always the first field of the tuple.
///
/// If the enum is only a single Leaf value, this expects the enum to consist only of the tag value.
fn enum_tag(&mut self, enum_value: &Values) -> ValueId {
match enum_value {
Tree::Branch(values) => self.enum_tag(&values[0]),
Tree::Leaf(value) => value.clone().eval(self),
}
}

/// Bind the given variable ids to each argument of the given enum, using the
/// variant at the given variant index. Note that this function makes assumptions that the
/// representation of an enum is:
///
/// (
/// tag_value,
/// (field0_0, .. field0_N), // fields of variant 0,
/// (field1_0, .. field1_N), // fields of variant 1,
/// ..,
/// (fieldM_0, .. fieldM_N), // fields of variant N,
/// )
fn bind_case_arguments(&mut self, enum_value: Values, case: &MatchCase) {
if !case.arguments.is_empty() {
if case.constructor.is_enum() {
self.bind_enum_case_arguments(enum_value, case);
} else if case.constructor.is_tuple_or_struct() {
self.bind_tuple_or_struct_case_arguments(enum_value, case);
}
}
}

fn bind_enum_case_arguments(&mut self, enum_value: Values, case: &MatchCase) {
let Tree::Branch(mut variants) = enum_value else {
unreachable!("Expected enum value to contain each variant");
};

let variant_index = case.constructor.variant_index();

// variant_index + 1 to account for the extra tag value
let Tree::Branch(variant) = variants.swap_remove(variant_index + 1) else {
unreachable!("Expected enum variant to contain a tag and each variant's arguments");
};

assert_eq!(
variant.len(),
case.arguments.len(),
"Expected enum variant to contain a value for each variant argument"
);

for (value, arg) in variant.into_iter().zip(&case.arguments) {
self.define(*arg, value);
}
}

fn bind_tuple_or_struct_case_arguments(&mut self, struct_value: Values, case: &MatchCase) {
let Tree::Branch(fields) = struct_value else {
unreachable!("Expected struct value to contain each field");
};

assert_eq!(
fields.len(),
case.arguments.len(),
"Expected field length to match constructor argument count"
);

for (value, arg) in fields.into_iter().zip(&case.arguments) {
self.define(*arg, value);
}
}

fn codegen_tuple(&mut self, tuple: &[Expression]) -> Result<Values, RuntimeError> {
Ok(Tree::Branch(try_vecmap(tuple, |expr| self.codegen_expression(expr))?))
}
Expand Down
1 change: 1 addition & 0 deletions compiler/noirc_frontend/src/ast/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ impl StatementKind {
| (ExpressionKind::Unsafe(..), semi, _)
| (ExpressionKind::Interned(..), semi, _)
| (ExpressionKind::InternedStatement(..), semi, _)
| (ExpressionKind::Match(..), semi, _)
| (ExpressionKind::If(_), semi, _) => {
if semi.is_some() {
StatementKind::Semi(expr)
Expand Down
Loading

0 comments on commit 4c3dee1

Please sign in to comment.