Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Keep inc_rc for array outputs during preprocessing #7163

Merged
merged 5 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 57 additions & 14 deletions compiler/noirc_evaluator/src/ssa/opt/die.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,16 @@ impl Context {
let block = &function.dfg[block_id];
self.mark_terminator_values_as_used(function, block);

let instructions_len = block.instructions().len();

let mut rc_tracker = RcTracker::default();
rc_tracker.mark_terminator_arrays_as_used(function, block);

let instructions_len = block.instructions().len();

// Indexes of instructions that might be out of bounds.
// We'll remove those, but before that we'll insert bounds checks for them.
let mut possible_index_out_of_bounds_indexes = Vec::new();

// Going in reverse so we know if a result of an instruction was used.
for (instruction_index, instruction_id) in block.instructions().iter().rev().enumerate() {
let instruction = &function.dfg[*instruction_id];

Expand Down Expand Up @@ -241,6 +243,8 @@ impl Context {
}
}

/// Go through the RC instructions collected when we figured out which values were unused;
/// for each RC that refers to an unused value, remove the RC as well.
fn remove_rc_instructions(&self, dfg: &mut DataFlowGraph) {
let unused_rc_values_by_block: HashMap<BasicBlockId, HashSet<InstructionId>> =
self.rc_instructions.iter().fold(HashMap::default(), |mut acc, (rc, block)| {
Expand Down Expand Up @@ -580,10 +584,12 @@ struct RcTracker {
// with the same value but no array set in between.
// If we see an inc/dec RC pair within a block we can safely remove both instructions.
rcs_with_possible_pairs: HashMap<Type, Vec<RcInstruction>>,
// Tracks repeated RC instructions: if there are two `inc_rc` for the same value in a row, the 2nd one is redundant.
rc_pairs_to_remove: HashSet<InstructionId>,
// We also separately track all IncrementRc instructions and all array types which have been mutably borrowed.
// If an array is the same type as one of those non-mutated array types, we can safely remove all IncrementRc instructions on that array.
inc_rcs: HashMap<ValueId, HashSet<InstructionId>>,
// When tracking mutations we consider arrays with the same type as all being possibly mutated.
mutated_array_types: HashSet<Type>,
// The SSA often creates patterns where after simplifications we end up with repeat
// IncrementRc instructions on the same value. We track whether the previous instruction was an IncrementRc,
Expand All @@ -593,9 +599,19 @@ struct RcTracker {
}

impl RcTracker {
fn mark_terminator_arrays_as_used(&mut self, function: &Function, block: &BasicBlock) {
block.unwrap_terminator().for_each_value(|value| {
let typ = function.dfg.type_of_value(value);
if matches!(&typ, Type::Array(_, _) | Type::Slice(_)) {
self.mutated_array_types.insert(typ);
}
});
}

fn track_inc_rcs_to_remove(&mut self, instruction_id: InstructionId, function: &Function) {
let instruction = &function.dfg[instruction_id];

// Deduplicate IncRC instructions.
if let Instruction::IncrementRc { value } = instruction {
if let Some(previous_value) = self.previous_inc_rc {
if previous_value == *value {
Expand All @@ -604,13 +620,16 @@ impl RcTracker {
}
self.previous_inc_rc = Some(*value);
} else {
// Reset the deduplication.
self.previous_inc_rc = None;
}

// DIE loops over a block in reverse order, so we insert an RC instruction for possible removal
// when we see a DecrementRc and check whether it was possibly mutated when we see an IncrementRc.
match instruction {
Instruction::IncrementRc { value } => {
// Get any RC instruction recorded further down the block for this array;
// if it exists and not marked as mutated, then both RCs can be removed.
if let Some(inc_rc) =
pop_rc_for(*value, function, &mut self.rcs_with_possible_pairs)
{
Expand All @@ -619,7 +638,7 @@ impl RcTracker {
self.rc_pairs_to_remove.insert(instruction_id);
}
}

// Remember that this array was RC'd by this instruction.
self.inc_rcs.entry(*value).or_default().insert(instruction_id);
}
Instruction::DecrementRc { value } => {
Expand All @@ -632,12 +651,12 @@ impl RcTracker {
}
Instruction::ArraySet { array, .. } => {
let typ = function.dfg.type_of_value(*array);
// We mark all RCs that refer to arrays with a matching type as the one being set, as possibly mutated.
if let Some(dec_rcs) = self.rcs_with_possible_pairs.get_mut(&typ) {
for dec_rc in dec_rcs {
dec_rc.possibly_mutated = true;
}
}

self.mutated_array_types.insert(typ);
}
Instruction::Store { value, .. } => {
Expand All @@ -648,6 +667,9 @@ impl RcTracker {
}
}
Instruction::Call { arguments, .. } => {
// Treat any array-type arguments to calls as possible sources of mutation.
// During the preprocessing of functions in isolation we don't want to
// get rid of IncRCs arrays that can potentially be mutated outside.
for arg in arguments {
let typ = function.dfg.type_of_value(*arg);
if matches!(&typ, Type::Array(..) | Type::Slice(..)) {
Expand All @@ -659,6 +681,7 @@ impl RcTracker {
}
}

/// Get all RC instructions which work on arrays whose type has not been marked as mutated.
fn get_non_mutated_arrays(&self, dfg: &DataFlowGraph) -> HashSet<InstructionId> {
self.inc_rcs
.keys()
Expand Down Expand Up @@ -857,16 +880,6 @@ mod test {

#[test]
fn keep_inc_rc_on_borrowed_array_set() {
// brillig(inline) fn main f0 {
// b0(v0: [u32; 2]):
// inc_rc v0
// v3 = array_set v0, index u32 0, value u32 1
// inc_rc v0
// inc_rc v0
// inc_rc v0
// v4 = array_get v3, index u32 1
// return v4
// }
let src = "
brillig(inline) fn main f0 {
b0(v0: [u32; 2]):
Expand Down Expand Up @@ -951,6 +964,36 @@ mod test {
assert_normalized_ssa_equals(ssa, expected);
}

#[test]
fn do_not_remove_inc_rcs_for_arrays_in_terminator() {
let src = "
brillig(inline) fn main f0 {
b0(v0: [Field; 2]):
inc_rc v0
inc_rc v0
inc_rc v0
v2 = array_get v0, index u32 0 -> Field
inc_rc v0
return v0, v2
}
";

let ssa = Ssa::from_str(src).unwrap();

let expected = "
brillig(inline) fn main f0 {
b0(v0: [Field; 2]):
inc_rc v0
v2 = array_get v0, index u32 0 -> Field
inc_rc v0
return v0, v2
}
";

let ssa = ssa.dead_instruction_elimination();
assert_normalized_ssa_equals(ssa, expected);
}

#[test]
fn do_not_remove_inc_rc_if_used_as_call_arg() {
// We do not want to remove inc_rc instructions on values
Expand Down
26 changes: 11 additions & 15 deletions compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@ impl Ssa {
// Bottom-up order, starting with the "leaf" functions, so we inline already optimized code into the ones that call them.
let bottom_up = inlining::compute_bottom_up_order(&self);

// As a heuristic to avoid optimizing functions near the entry point, find a cutoff weight.
let total_weight =
bottom_up.iter().fold(0usize, |acc, (_, (_, w))| (acc.saturating_add(*w)));
let mean_weight = total_weight / bottom_up.len();
let cutoff_weight = mean_weight;

// Preliminary inlining decisions.
let inline_infos = inlining::compute_inline_infos(&self, false, aggressiveness);

Expand All @@ -36,19 +30,21 @@ impl Ssa {
};

for (id, (own_weight, transitive_weight)) in bottom_up {
// Skip preprocessing heavy functions that gained most of their weight from transitive accumulation.
let function = &self.functions[&id];

// Skip preprocessing heavy functions that gained most of their weight from transitive accumulation, which tend to be near the entry.
// These can be processed later by the regular SSA passes.
if transitive_weight >= cutoff_weight && transitive_weight > own_weight * 2 {
continue;
}
let is_heavy = transitive_weight > own_weight * 10;

// Functions which are inline targets will be processed in later passes.
// Here we want to treat the functions which will be inlined into them.
if let Some(info) = inline_infos.get(&id) {
if info.is_inline_target() {
continue;
}
let is_target =
inline_infos.get(&id).map(|info| info.is_inline_target()).unwrap_or_default();

if is_heavy || is_target {
continue;
}
let function = &self.functions[&id];

// Start with an inline pass.
let mut function = function.inlined(&self, &should_inline_call);
// Help unrolling determine bounds.
Expand Down
7 changes: 7 additions & 0 deletions test_programs/execution_success/regression_11294/Nargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "regression_11294"
version = "0.1.0"
type = "bin"
authors = [""]

[dependencies]
47 changes: 47 additions & 0 deletions test_programs/execution_success/regression_11294/Prover.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
[[previous_kernel_public_inputs.end.private_call_stack]]
args_hash = "0x0c78b411fc893c51d446c08daa5741b9ba6103126c9e450bed90fcde8793168a"
returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000002"
end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000007"

[[previous_kernel_public_inputs.end.private_call_stack]]
args_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"
end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"

[[previous_kernel_public_inputs.end.private_call_stack]]
args_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"
end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"

[[previous_kernel_public_inputs.end.private_call_stack]]
args_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"
end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"

[[previous_kernel_public_inputs.end.private_call_stack]]
args_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"
end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"

[[previous_kernel_public_inputs.end.private_call_stack]]
args_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"
end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"

[[previous_kernel_public_inputs.end.private_call_stack]]
args_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"
end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"

[[previous_kernel_public_inputs.end.private_call_stack]]
args_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
returns_hash = "0x0000000000000000000000000000000000000000000000000000000000000000"
start_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"
end_side_effect_counter = "0x0000000000000000000000000000000000000000000000000000000000000000"
Loading
Loading