diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index b4ea06f1030..c24cc1a9c86 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -32,7 +32,7 @@ use crate::{ Kind, QuotedType, Shared, StructType, Type, }; -use super::{Elaborator, LambdaContext}; +use super::{Elaborator, LambdaContext, UnsafeBlockStatus}; impl<'context> Elaborator<'context> { pub(crate) fn elaborate_expression(&mut self, expr: Expression) -> (ExprId, Type) { @@ -59,7 +59,7 @@ impl<'context> Elaborator<'context> { return self.elaborate_comptime_block(comptime, expr.span) } ExpressionKind::Unsafe(block_expression, _) => { - self.elaborate_unsafe_block(block_expression) + self.elaborate_unsafe_block(block_expression, expr.span) } ExpressionKind::Resolved(id) => return (id, self.interner.id_type(id)), ExpressionKind::Interned(id) => { @@ -140,15 +140,36 @@ impl<'context> Elaborator<'context> { (HirBlockExpression { statements }, block_type) } - fn elaborate_unsafe_block(&mut self, block: BlockExpression) -> (HirExpression, Type) { + fn elaborate_unsafe_block( + &mut self, + block: BlockExpression, + span: Span, + ) -> (HirExpression, Type) { // Before entering the block we cache the old value of `in_unsafe_block` so it can be restored. - let old_in_unsafe_block = self.in_unsafe_block; - self.in_unsafe_block = true; + let old_in_unsafe_block = self.unsafe_block_status; + let is_nested_unsafe_block = + !matches!(old_in_unsafe_block, UnsafeBlockStatus::NotInUnsafeBlock); + if is_nested_unsafe_block { + let span = Span::from(span.start()..span.start() + 6); // Only highlight the `unsafe` keyword + self.push_err(TypeCheckError::NestedUnsafeBlock { span }); + } + + self.unsafe_block_status = UnsafeBlockStatus::InUnsafeBlockWithoutUnconstrainedCalls; let (hir_block_expression, typ) = self.elaborate_block_expression(block); - // Finally, we restore the original value of `self.in_unsafe_block`. - self.in_unsafe_block = old_in_unsafe_block; + if let UnsafeBlockStatus::InUnsafeBlockWithoutUnconstrainedCalls = self.unsafe_block_status + { + let span = Span::from(span.start()..span.start() + 6); // Only highlight the `unsafe` keyword + self.push_err(TypeCheckError::UnnecessaryUnsafeBlock { span }); + } + + // Finally, we restore the original value of `self.in_unsafe_block`, + // but only if this isn't a nested unsafe block (that way if we found an unconstrained call + // for this unsafe block we'll also consider the outer one as finding one, and we don't double error) + if !is_nested_unsafe_block { + self.unsafe_block_status = old_in_unsafe_block; + } (HirExpression::Unsafe(hir_block_expression), typ) } diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index 55144f8944a..593ea6b20e8 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -79,6 +79,16 @@ pub struct LambdaContext { pub scope_index: usize, } +/// Determines whether we are in an unsafe block and, if so, whether +/// any unconstrained calls were found in it (because if not we'll warn +/// that the unsafe block is not needed). +#[derive(Copy, Clone)] +enum UnsafeBlockStatus { + NotInUnsafeBlock, + InUnsafeBlockWithoutUnconstrainedCalls, + InUnsafeBlockWithConstrainedCalls, +} + pub struct Elaborator<'context> { scopes: ScopeForest, @@ -90,7 +100,7 @@ pub struct Elaborator<'context> { pub(crate) file: FileId, - in_unsafe_block: bool, + unsafe_block_status: UnsafeBlockStatus, nested_loops: usize, /// Contains a mapping of the current struct or functions's generics to @@ -202,7 +212,7 @@ impl<'context> Elaborator<'context> { def_maps, usage_tracker, file: FileId::dummy(), - in_unsafe_block: false, + unsafe_block_status: UnsafeBlockStatus::NotInUnsafeBlock, nested_loops: 0, generics: Vec::new(), lambda_stack: Vec::new(), diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index 9aafd690bb6..550ee41fbd4 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -35,7 +35,7 @@ use crate::{ Generics, Kind, ResolvedGeneric, Type, TypeBinding, TypeBindings, UnificationError, }; -use super::{lints, path_resolution::PathResolutionItem, Elaborator}; +use super::{lints, path_resolution::PathResolutionItem, Elaborator, UnsafeBlockStatus}; pub const SELF_TYPE_NAME: &str = "Self"; @@ -1483,8 +1483,14 @@ impl<'context> Elaborator<'context> { func_type_is_unconstrained || self.is_unconstrained_call(call.func); let crossing_runtime_boundary = is_current_func_constrained && is_unconstrained_call; if crossing_runtime_boundary { - if !self.in_unsafe_block { - self.push_err(TypeCheckError::Unsafe { span }); + match self.unsafe_block_status { + UnsafeBlockStatus::NotInUnsafeBlock => { + self.push_err(TypeCheckError::Unsafe { span }); + } + UnsafeBlockStatus::InUnsafeBlockWithoutUnconstrainedCalls => { + self.unsafe_block_status = UnsafeBlockStatus::InUnsafeBlockWithConstrainedCalls; + } + UnsafeBlockStatus::InUnsafeBlockWithConstrainedCalls => (), } if let Some(called_func_id) = self.interner.lookup_function_from_expr(&call.func) { diff --git a/compiler/noirc_frontend/src/hir/type_check/errors.rs b/compiler/noirc_frontend/src/hir/type_check/errors.rs index 08864b919e3..16422e0ef8b 100644 --- a/compiler/noirc_frontend/src/hir/type_check/errors.rs +++ b/compiler/noirc_frontend/src/hir/type_check/errors.rs @@ -205,6 +205,10 @@ pub enum TypeCheckError { CyclicType { typ: Type, span: Span }, #[error("Type annotations required before indexing this array or slice")] TypeAnnotationsNeededForIndex { span: Span }, + #[error("Unnecessary `unsafe` block")] + UnnecessaryUnsafeBlock { span: Span }, + #[error("Unnecessary `unsafe` block")] + NestedUnsafeBlock { span: Span }, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -517,6 +521,20 @@ impl<'a> From<&'a TypeCheckError> for Diagnostic { *span, ) }, + TypeCheckError::UnnecessaryUnsafeBlock { span } => { + Diagnostic::simple_warning( + "Unnecessary `unsafe` block".into(), + "".into(), + *span, + ) + }, + TypeCheckError::NestedUnsafeBlock { span } => { + Diagnostic::simple_warning( + "Unnecessary `unsafe` block".into(), + "Because it's nested inside another `unsafe` block".into(), + *span, + ) + }, } } } diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index 2d7cf8acca6..5bfcf2a65a4 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -3875,3 +3875,43 @@ fn errors_on_cyclic_globals() { CompilationError::ResolverError(ResolverError::DependencyCycle { .. }) ))); } + +#[test] +fn warns_on_unneeded_unsafe() { + let src = r#" + fn main() { + unsafe { + foo() + } + } + + fn foo() {} + "#; + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + assert!(matches!( + &errors[0].0, + CompilationError::TypeError(TypeCheckError::UnnecessaryUnsafeBlock { .. }) + )); +} + +#[test] +fn warns_on_nested_unsafe() { + let src = r#" + fn main() { + unsafe { + unsafe { + foo() + } + } + } + + unconstrained fn foo() {} + "#; + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + assert!(matches!( + &errors[0].0, + CompilationError::TypeError(TypeCheckError::NestedUnsafeBlock { .. }) + )); +} diff --git a/noir_stdlib/src/collections/map.nr b/noir_stdlib/src/collections/map.nr index 2b0da1b90ec..bc0b80124db 100644 --- a/noir_stdlib/src/collections/map.nr +++ b/noir_stdlib/src/collections/map.nr @@ -271,7 +271,7 @@ impl HashMap { for slot in self._table { if slot.is_valid() { - let (_, value) = unsafe { slot.key_value_unchecked() }; + let (_, value) = slot.key_value_unchecked(); values.push(value); } } diff --git a/noir_stdlib/src/meta/mod.nr b/noir_stdlib/src/meta/mod.nr index 5d2164a510d..f2234300ab2 100644 --- a/noir_stdlib/src/meta/mod.nr +++ b/noir_stdlib/src/meta/mod.nr @@ -52,7 +52,7 @@ pub comptime fn derive(s: StructDefinition, traits: [TraitDefinition]) -> Quoted let mut result = quote {}; for trait_to_derive in traits { - let handler = unsafe { HANDLERS.get(trait_to_derive) }; + let handler = HANDLERS.get(trait_to_derive); assert(handler.is_some(), f"No derive function registered for `{trait_to_derive}`"); let trait_impl = handler.unwrap()(s);