diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index e1b4a8bcce0..ce8545d2ccc 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -29,7 +29,7 @@ use crate::hir::def_map::{LocalModuleId, ModuleDefId, TryFromModuleDefId, MAIN_F use crate::hir_def::stmt::{HirAssignStatement, HirForStatement, HirLValue, HirPattern}; use crate::node_interner::{ DefinitionId, DefinitionKind, ExprId, FuncId, NodeInterner, StmtId, StructId, TraitId, - TraitImplId, + TraitImplId, TraitImplKind, }; use crate::{ hir::{def_map::CrateDefMap, resolution::path_resolver::PathResolver}, @@ -1207,8 +1207,12 @@ impl<'a> Resolver<'a> { Literal::Unit => HirLiteral::Unit, }), ExpressionKind::Variable(path) => { - if let Some(expr) = self.resolve_trait_generic_path(&path) { - expr + if let Some((hir_expr, object_type)) = self.resolve_trait_generic_path(&path) { + let expr_id = self.interner.push_expr(hir_expr); + self.interner.push_expr_location(expr_id, expr.span, self.file); + self.interner + .select_impl_for_ident(expr_id, TraitImplKind::Assumed { object_type }); + return expr_id; } else { // If the Path is being used as an Expression, then it is referring to a global from a separate module // Otherwise, then it is referring to an Identifier @@ -1370,6 +1374,8 @@ impl<'a> Resolver<'a> { ExpressionKind::Parenthesized(sub_expr) => return self.resolve_expression(*sub_expr), }; + // If these lines are ever changed, make sure to change the early return + // in the ExpressionKind::Variable case as well let expr_id = self.interner.push_expr(hir_expr); self.interner.push_expr_location(expr_id, expr.span, self.file); expr_id @@ -1576,7 +1582,10 @@ impl<'a> Resolver<'a> { } // this resolves Self::some_static_method, inside an impl block (where we don't have a concrete self_type) - fn resolve_trait_static_method_by_self(&mut self, path: &Path) -> Option { + fn resolve_trait_static_method_by_self( + &mut self, + path: &Path, + ) -> Option<(HirExpression, Type)> { if let Some(trait_id) = self.trait_id { if path.kind == PathKind::Plain && path.segments.len() == 2 { let name = &path.segments[0].0.contents; @@ -1590,7 +1599,7 @@ impl<'a> Resolver<'a> { the_trait.self_type_typevar, crate::TypeVariableKind::Normal, ); - return Some(HirExpression::TraitMethodReference(self_type, method)); + return Some((HirExpression::TraitMethodReference(method), self_type)); } } } @@ -1599,7 +1608,10 @@ impl<'a> Resolver<'a> { } // this resolves a static trait method T::trait_method by iterating over the where clause - fn resolve_trait_method_by_named_generic(&mut self, path: &Path) -> Option { + fn resolve_trait_method_by_named_generic( + &mut self, + path: &Path, + ) -> Option<(HirExpression, Type)> { if path.segments.len() != 2 { return None; } @@ -1621,7 +1633,7 @@ impl<'a> Resolver<'a> { the_trait.find_method(path.segments.last().unwrap().clone()) { let self_type = self.resolve_type(typ.clone()); - return Some(HirExpression::TraitMethodReference(self_type, method)); + return Some((HirExpression::TraitMethodReference(method), self_type)); } } } @@ -1629,7 +1641,7 @@ impl<'a> Resolver<'a> { None } - fn resolve_trait_generic_path(&mut self, path: &Path) -> Option { + fn resolve_trait_generic_path(&mut self, path: &Path) -> Option<(HirExpression, Type)> { self.resolve_trait_static_method_by_self(path) .or_else(|| self.resolve_trait_method_by_named_generic(path)) } diff --git a/compiler/noirc_frontend/src/hir/type_check/errors.rs b/compiler/noirc_frontend/src/hir/type_check/errors.rs index 4a46391f0d4..267dbd6b5be 100644 --- a/compiler/noirc_frontend/src/hir/type_check/errors.rs +++ b/compiler/noirc_frontend/src/hir/type_check/errors.rs @@ -113,6 +113,8 @@ pub enum TypeCheckError { }, #[error("No matching impl found")] NoMatchingImplFound { constraints: Vec<(Type, String)>, span: Span }, + #[error("Constraint for `{typ}: {trait_name}` is not needed, another matching impl is already in scope")] + UnneededTraitConstraint { trait_name: String, typ: Type, span: Span }, } impl TypeCheckError { @@ -269,6 +271,10 @@ impl From for Diagnostic { diagnostic } + TypeCheckError::UnneededTraitConstraint { trait_name, typ, span } => { + let msg = format!("Constraint for `{typ}: {trait_name}` is not needed, another matching impl is already in scope"); + Diagnostic::simple_warning(msg, "Unnecessary trait constraint in where clause".into(), span) + } } } } diff --git a/compiler/noirc_frontend/src/hir/type_check/expr.rs b/compiler/noirc_frontend/src/hir/type_check/expr.rs index c0ff4dff6d5..955863a74e0 100644 --- a/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -10,7 +10,7 @@ use crate::{ }, types::Type, }, - node_interner::{DefinitionKind, ExprId, FuncId, TraitMethodId}, + node_interner::{DefinitionKind, ExprId, FuncId, TraitId, TraitMethodId}, BinaryOpKind, Signedness, TypeBinding, TypeVariableKind, UnaryOp, }; @@ -132,7 +132,9 @@ impl<'interner> TypeChecker<'interner> { HirExpression::Index(index_expr) => self.check_index_expression(expr_id, index_expr), HirExpression::Call(call_expr) => { self.check_if_deprecated(&call_expr.func); + let function = self.check_expression(&call_expr.func); + let args = vecmap(&call_expr.arguments, |arg| { let typ = self.check_expression(arg); (typ, *arg, self.interner.expr_span(arg)) @@ -160,21 +162,29 @@ impl<'interner> TypeChecker<'interner> { // so that the backend doesn't need to worry about methods let location = method_call.location; - let mut func_id = None; - if let HirMethodReference::FuncId(id) = method_ref { - func_id = Some(id); - - // Automatically add `&mut` if the method expects a mutable reference and - // the object is not already one. - if id != FuncId::dummy_id() { - let func_meta = self.interner.function_meta(&id); - self.try_add_mutable_reference_to_object( - &mut method_call, - &func_meta.typ, - &mut args, - ); + let trait_id = match &method_ref { + HirMethodReference::FuncId(func_id) => { + // Automatically add `&mut` if the method expects a mutable reference and + // the object is not already one. + if *func_id != FuncId::dummy_id() { + let func_meta = self.interner.function_meta(func_id); + self.try_add_mutable_reference_to_object( + &mut method_call, + &func_meta.typ, + &mut args, + ); + } + + let meta = self.interner.function_meta(func_id); + meta.trait_impl.map(|impl_id| { + self.interner + .get_trait_implementation(impl_id) + .borrow() + .trait_id + }) } - } + HirMethodReference::TraitMethodId(method) => Some(method.trait_id), + }; let (function_id, function_call) = method_call.into_function_call( method_ref.clone(), @@ -185,29 +195,8 @@ impl<'interner> TypeChecker<'interner> { let span = self.interner.expr_span(expr_id); let ret = self.check_method_call(&function_id, method_ref, args, span); - if let Some(func_id) = func_id { - let meta = self.interner.function_meta(&func_id); - - if let Some(impl_id) = meta.trait_impl { - let trait_impl = self.interner.get_trait_implementation(impl_id); - - let result = self.interner.lookup_trait_implementation( - &object_type, - trait_impl.borrow().trait_id, - ); - - if let Err(erroring_constraints) = result { - let constraints = vecmap(erroring_constraints, |constraint| { - let r#trait = self.interner.get_trait(constraint.trait_id); - (constraint.typ, r#trait.name.to_string()) - }); - - self.errors.push(TypeCheckError::NoMatchingImplFound { - constraints, - span, - }); - } - } + if let Some(trait_id) = trait_id { + self.verify_trait_constraint(&object_type, trait_id, function_id, span); } self.interner.replace_expr(expr_id, function_call); @@ -285,7 +274,7 @@ impl<'interner> TypeChecker<'interner> { Type::Function(params, Box::new(lambda.return_type), Box::new(env_type)) } - HirExpression::TraitMethodReference(_, method) => { + HirExpression::TraitMethodReference(method) => { let the_trait = self.interner.get_trait(method.trait_id); let method = &the_trait.methods[method.method_index]; @@ -305,6 +294,26 @@ impl<'interner> TypeChecker<'interner> { typ } + fn verify_trait_constraint( + &mut self, + object_type: &Type, + trait_id: TraitId, + function_ident_id: ExprId, + span: Span, + ) { + match self.interner.lookup_trait_implementation(object_type, trait_id) { + Ok(impl_kind) => self.interner.select_impl_for_ident(function_ident_id, impl_kind), + Err(erroring_constraints) => { + let constraints = vecmap(erroring_constraints, |constraint| { + let r#trait = self.interner.get_trait(constraint.trait_id); + (constraint.typ, r#trait.name.to_string()) + }); + + self.errors.push(TypeCheckError::NoMatchingImplFound { constraints, span }); + } + } + } + /// Check if the given method type requires a mutable reference to the object type, and check /// if the given object type is already a mutable reference. If not, add one. /// This is used to automatically transform a method call: `foo.bar()` into a function @@ -512,13 +521,11 @@ impl<'interner> TypeChecker<'interner> { let func_meta = self.interner.function_meta(&func_id); let param_len = func_meta.parameters.len(); - (func_meta.typ, param_len) } - HirMethodReference::TraitMethodId(_, method) => { + HirMethodReference::TraitMethodId(method) => { let the_trait = self.interner.get_trait(method.trait_id); let method = &the_trait.methods[method.method_index]; - (method.get_type(), method.arguments.len()) } }; @@ -537,7 +544,6 @@ impl<'interner> TypeChecker<'interner> { self.interner.store_instantiation_bindings(*function_ident_id, instantiation_bindings); self.interner.push_expr_type(function_ident_id, function_type.clone()); - self.bind_function_type(function_type, arguments, span) } @@ -889,10 +895,7 @@ impl<'interner> TypeChecker<'interner> { if method.name.0.contents == method_name { let trait_method = TraitMethodId { trait_id: constraint.trait_id, method_index }; - return Some(HirMethodReference::TraitMethodId( - object_type.clone(), - trait_method, - )); + return Some(HirMethodReference::TraitMethodId(trait_method)); } } } diff --git a/compiler/noirc_frontend/src/hir/type_check/mod.rs b/compiler/noirc_frontend/src/hir/type_check/mod.rs index 3915132383b..01c68adbb6d 100644 --- a/compiler/noirc_frontend/src/hir/type_check/mod.rs +++ b/compiler/noirc_frontend/src/hir/type_check/mod.rs @@ -43,6 +43,22 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec Vec Vec), Lambda(HirLambda), - TraitMethodReference(Type, TraitMethodId), + TraitMethodReference(TraitMethodId), Error, } @@ -156,7 +156,7 @@ pub enum HirMethodReference { /// Or a method can come from a Trait impl block, in which case /// the actual function called will depend on the instantiated type, /// which can be only known during monomorphization. - TraitMethodId(Type, TraitMethodId), + TraitMethodId(TraitMethodId), } impl HirMethodCallExpression { @@ -174,8 +174,8 @@ impl HirMethodCallExpression { let id = interner.function_definition_id(func_id); HirExpression::Ident(HirIdent { location, id }) } - HirMethodReference::TraitMethodId(typ, method_id) => { - HirExpression::TraitMethodReference(typ, method_id) + HirMethodReference::TraitMethodId(method_id) => { + HirExpression::TraitMethodReference(method_id) } }; let func = interner.push_expr(expr); diff --git a/compiler/noirc_frontend/src/hir_def/traits.rs b/compiler/noirc_frontend/src/hir_def/traits.rs index 2a762a38fc6..062f505c179 100644 --- a/compiler/noirc_frontend/src/hir_def/traits.rs +++ b/compiler/noirc_frontend/src/hir_def/traits.rs @@ -151,5 +151,6 @@ impl TraitFunction { Box::new(self.return_type.clone()), Box::new(Type::Unit), ) + .generalize() } } diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index e2951f9185e..8ad38b526de 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -575,6 +575,32 @@ impl Type { _ => 0, } } + + /// Takes a monomorphic type and generalizes it over each of the given type variables. + pub(crate) fn generalize_from_variables( + self, + type_vars: HashMap, + ) -> Type { + let polymorphic_type_vars = vecmap(type_vars, |type_var| type_var); + Type::Forall(polymorphic_type_vars, Box::new(self)) + } + + /// Takes a monomorphic type and generalizes it over each of the type variables in the + /// given type bindings, ignoring what each type variable is bound to in the TypeBindings. + pub(crate) fn generalize_from_substitutions(self, type_bindings: TypeBindings) -> Type { + let polymorphic_type_vars = vecmap(type_bindings, |(id, (type_var, _))| (id, type_var)); + Type::Forall(polymorphic_type_vars, Box::new(self)) + } + + /// Takes a monomorphic type and generalizes it over each type variable found within. + /// + /// Note that Noir's type system assumes any Type::Forall are only present at top-level, + /// and thus all type variable's within a type are free. + pub(crate) fn generalize(self) -> Type { + let mut type_variables = HashMap::new(); + self.find_all_unbound_type_variables(&mut type_variables); + self.generalize_from_variables(type_variables) + } } impl std::fmt::Display for Type { @@ -926,8 +952,24 @@ impl Type { } } + (NamedGeneric(binding, _), other) if !binding.borrow().is_unbound() => { + if let TypeBinding::Bound(link) = &*binding.borrow() { + link.try_unify(other) + } else { + unreachable!("If guard ensures binding is bound") + } + } + + (other, NamedGeneric(binding, _)) if !binding.borrow().is_unbound() => { + if let TypeBinding::Bound(link) = &*binding.borrow() { + other.try_unify(link) + } else { + unreachable!("If guard ensures binding is bound") + } + } + (NamedGeneric(binding_a, name_a), NamedGeneric(binding_b, name_b)) => { - // Ensure NamedGenerics are never bound during type checking + // Unbound NamedGenerics are caught by the checks above assert!(binding_a.borrow().is_unbound()); assert!(binding_b.borrow().is_unbound()); @@ -1085,12 +1127,18 @@ impl Type { } /// Replace each NamedGeneric (and TypeVariable) in this type with a fresh type variable - pub(crate) fn instantiate_named_generics( + pub(crate) fn instantiate_type_variables( &self, interner: &NodeInterner, ) -> (Type, TypeBindings) { - let mut substitutions = HashMap::new(); - self.find_all_unbound_type_variables(interner, &mut substitutions); + let mut type_variables = HashMap::new(); + self.find_all_unbound_type_variables(&mut type_variables); + + let substitutions = type_variables + .into_iter() + .map(|(id, type_var)| (id, (type_var, interner.next_type_variable()))) + .collect(); + (self.substitute(&substitutions), substitutions) } @@ -1098,8 +1146,7 @@ impl Type { /// to bind the unbound type variable to a fresh type variable. fn find_all_unbound_type_variables( &self, - interner: &NodeInterner, - bindings: &mut TypeBindings, + type_variables: &mut HashMap, ) { match self { Type::FieldElement @@ -1111,44 +1158,43 @@ impl Type { | Type::NotConstant | Type::Error => (), Type::Array(length, elem) => { - length.find_all_unbound_type_variables(interner, bindings); - elem.find_all_unbound_type_variables(interner, bindings); + length.find_all_unbound_type_variables(type_variables); + elem.find_all_unbound_type_variables(type_variables); } - Type::String(length) => length.find_all_unbound_type_variables(interner, bindings), + Type::String(length) => length.find_all_unbound_type_variables(type_variables), Type::FmtString(length, env) => { - length.find_all_unbound_type_variables(interner, bindings); - env.find_all_unbound_type_variables(interner, bindings); + length.find_all_unbound_type_variables(type_variables); + env.find_all_unbound_type_variables(type_variables); } Type::Struct(_, generics) => { for generic in generics { - generic.find_all_unbound_type_variables(interner, bindings); + generic.find_all_unbound_type_variables(type_variables); } } Type::Tuple(fields) => { for field in fields { - field.find_all_unbound_type_variables(interner, bindings); + field.find_all_unbound_type_variables(type_variables); } } Type::Function(args, ret, env) => { for arg in args { - arg.find_all_unbound_type_variables(interner, bindings); + arg.find_all_unbound_type_variables(type_variables); } - ret.find_all_unbound_type_variables(interner, bindings); - env.find_all_unbound_type_variables(interner, bindings); + ret.find_all_unbound_type_variables(type_variables); + env.find_all_unbound_type_variables(type_variables); } Type::MutableReference(elem) => { - elem.find_all_unbound_type_variables(interner, bindings); + elem.find_all_unbound_type_variables(type_variables); } - Type::Forall(_, typ) => typ.find_all_unbound_type_variables(interner, bindings), + Type::Forall(_, typ) => typ.find_all_unbound_type_variables(type_variables), Type::TypeVariable(type_variable, _) | Type::NamedGeneric(type_variable, _) => { match &*type_variable.borrow() { TypeBinding::Bound(binding) => { - binding.find_all_unbound_type_variables(interner, bindings); + binding.find_all_unbound_type_variables(type_variables); } TypeBinding::Unbound(id) => { - if !bindings.contains_key(id) { - let fresh_type_variable = interner.next_type_variable(); - bindings.insert(*id, (type_variable.clone(), fresh_type_variable)); + if !type_variables.contains_key(id) { + type_variables.insert(*id, type_variable.clone()); } } } diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index 44d734cd5f9..634569bbc7a 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -24,7 +24,7 @@ use crate::{ stmt::{HirAssignStatement, HirLValue, HirLetStatement, HirPattern, HirStatement}, types, }, - node_interner::{self, DefinitionKind, NodeInterner, StmtId, TraitMethodId}, + node_interner::{self, DefinitionKind, NodeInterner, StmtId, TraitImplKind, TraitMethodId}, token::FunctionAttribute, ContractFunctionType, FunctionKind, Type, TypeBinding, TypeBindings, TypeVariableKind, Visibility, @@ -378,9 +378,9 @@ impl<'interner> Monomorphizer<'interner> { HirExpression::Lambda(lambda) => self.lambda(lambda, expr), - HirExpression::TraitMethodReference(typ, method) => { + HirExpression::TraitMethodReference(method) => { if let Type::Function(_, _, _) = self.interner.id_type(expr) { - self.resolve_trait_method_reference(typ, expr, method) + self.resolve_trait_method_reference(expr, method) } else { unreachable!( "Calling a non-function, this should've been caught in typechecking" @@ -812,7 +812,6 @@ impl<'interner> Monomorphizer<'interner> { fn resolve_trait_method_reference( &mut self, - self_type: HirType, expr_id: node_interner::ExprId, method: TraitMethodId, ) -> ast::Expression { @@ -820,10 +819,29 @@ impl<'interner> Monomorphizer<'interner> { let trait_impl = self .interner - .lookup_trait_implementation(&self_type, method.trait_id) + .get_selected_impl_for_ident(expr_id) .expect("ICE: missing trait impl - should be caught during type checking"); - let hir_func_id = trait_impl.borrow().methods[method.method_index]; + let hir_func_id = match trait_impl { + node_interner::TraitImplKind::Normal(impl_id) => { + self.interner.get_trait_implementation(impl_id).borrow().methods + [method.method_index] + } + node_interner::TraitImplKind::Assumed { object_type } => { + match self.interner.lookup_trait_implementation(&object_type, method.trait_id) { + Ok(TraitImplKind::Normal(impl_id)) => { + self.interner.get_trait_implementation(impl_id).borrow().methods + [method.method_index] + } + Ok(TraitImplKind::Assumed { .. }) => unreachable!( + "There should be no remaining Assumed impls during monomorphization" + ), + Err(constraints) => { + unreachable!("Failed to find trait impl during monomorphization. The failed constraint(s) are:\n {constraints:?}") + } + } + } + }; let func_def = self.lookup_function(hir_func_id, expr_id, &function_type); let func_id = match func_def { diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index d03ed5528d9..e4532e2dceb 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -101,7 +101,13 @@ pub struct NodeInterner { /// we cannot map from Type directly to impl, we need to iterate a Vec of all impls /// of that trait to see if any type may match. This can be further optimized later /// by splitting it up by type. - trait_implementation_map: HashMap>, + trait_implementation_map: HashMap>, + + /// When impls are found during type checking, we tag the function call's Ident + /// with the impl that was selected. For cases with where clauses, this may be + /// an Assumed (but verified) impl. In this case the monomorphizer should have + /// the context to get the concrete type of the object and select the correct impl itself. + selected_trait_implementations: HashMap, /// Map from ExprId (referring to a Function/Method call) to its corresponding TypeBindings, /// filled out during type checking from instantiated variables. Used during monomorphization @@ -131,6 +137,18 @@ pub struct NodeInterner { func_id_to_trait: HashMap, } +/// A trait implementation is either a normal implementation that is present in the source +/// program via an `impl` block, or it is assumed to exist from a `where` clause or similar. +#[derive(Debug, Clone)] +pub enum TraitImplKind { + Normal(TraitImplId), + + /// Assumed impls don't have an impl id since they don't link back to any concrete part of the source code. + Assumed { + object_type: Type, + }, +} + /// Represents the methods on a given type that each share the same name. /// /// Methods are split into inherent methods and trait methods. If there is @@ -405,6 +423,7 @@ impl Default for NodeInterner { traits: HashMap::new(), trait_implementations: Vec::new(), trait_implementation_map: HashMap::new(), + selected_trait_implementations: HashMap::new(), instantiation_bindings: HashMap::new(), field_indices: HashMap::new(), next_type_variable_id: std::cell::Cell::new(0), @@ -960,7 +979,7 @@ impl NodeInterner { &self, object_type: &Type, trait_id: TraitId, - ) -> Result, Vec> { + ) -> Result> { self.lookup_trait_implementation_helper(object_type, trait_id, IMPL_SEARCH_RECURSION_LIMIT) } @@ -969,7 +988,7 @@ impl NodeInterner { object_type: &Type, trait_id: TraitId, recursion_limit: u32, - ) -> Result, Vec> { + ) -> Result> { let make_constraint = || TraitConstraint::new(object_type.clone(), trait_id); // Prevent infinite recursion when looking for impls @@ -980,23 +999,25 @@ impl NodeInterner { let impls = self.trait_implementation_map.get(&trait_id).ok_or_else(|| vec![make_constraint()])?; - for (existing_object_type, impl_id) in impls { - let (existing_object_type, type_bindings) = - existing_object_type.instantiate_named_generics(self); + for (existing_object_type, impl_kind) in impls { + let (existing_object_type, type_bindings) = existing_object_type.instantiate(self); if object_type.try_unify(&existing_object_type).is_ok() { - let trait_impl = self.get_trait_implementation(*impl_id); - - if let Err(mut errors) = self.validate_where_clause( - &trait_impl.borrow().where_clause, - &type_bindings, - recursion_limit, - ) { - errors.push(make_constraint()); - return Err(errors); + if let TraitImplKind::Normal(impl_id) = impl_kind { + let trait_impl = self.get_trait_implementation(*impl_id); + let trait_impl = trait_impl.borrow(); + + if let Err(mut errors) = self.validate_where_clause( + &trait_impl.where_clause, + &type_bindings, + recursion_limit, + ) { + errors.push(make_constraint()); + return Err(errors); + } } - return Ok(trait_impl); + return Ok(impl_kind.clone()); } } @@ -1022,6 +1043,30 @@ impl NodeInterner { Ok(()) } + /// Adds an "assumed" trait implementation to the currently known trait implementations. + /// Unlike normal trait implementations, these are only assumed to exist. They often correspond + /// to `where` clauses in functions where we assume there is some `T: Eq` even though we do + /// not yet know T. For these cases, we store an impl here so that we assume they exist and + /// can resolve them. They are then later verified when the function is called, and linked + /// properly after being monomorphized to the correct variant. + /// + /// Returns true on success, or false if there is already an overlapping impl in scope. + pub fn add_assumed_trait_implementation( + &mut self, + object_type: Type, + trait_id: TraitId, + ) -> bool { + // Make sure there are no overlapping impls + if self.lookup_trait_implementation(&object_type, trait_id).is_ok() { + return false; + } + + let entries = self.trait_implementation_map.entry(trait_id).or_default(); + entries.push((object_type.clone(), TraitImplKind::Assumed { object_type })); + true + } + + /// Adds a trait implementation to the list of known implementations. pub fn add_trait_implementation( &mut self, object_type: Type, @@ -1033,10 +1078,17 @@ impl NodeInterner { self.trait_implementations.push(trait_impl.clone()); - let (instantiated_object_type, _) = object_type.instantiate_named_generics(self); - if let Ok(existing_impl) = + // Ignoring overlapping TraitImplKind::Assumed impls here is perfectly fine. + // It should never happen since impls are defined at global scope, but even + // if they were, we should never prevent defining a new impl because a where + // clause already assumes it exists. + let (instantiated_object_type, substitutions) = + object_type.instantiate_type_variables(self); + + if let Ok(TraitImplKind::Normal(existing)) = self.lookup_trait_implementation(&instantiated_object_type, trait_id) { + let existing_impl = self.get_trait_implementation(existing); let existing_impl = existing_impl.borrow(); return Err((existing_impl.ident.span(), existing_impl.file)); } @@ -1046,8 +1098,11 @@ impl NodeInterner { self.add_method(&object_type, method_name, *method, true); } + // The object type is generalized so that a generic impl will apply + // to any type T, rather than just the generic type named T. + let generalized_object_type = object_type.generalize_from_substitutions(substitutions); let entries = self.trait_implementation_map.entry(trait_id).or_default(); - entries.push((object_type, impl_id)); + entries.push((generalized_object_type, TraitImplKind::Normal(impl_id))); Ok(()) } @@ -1119,9 +1174,29 @@ impl NodeInterner { /// Returns what the next trait impl id is expected to be. /// Note that this does not actually reserve the slot so care should /// be taken that the next trait impl added matches this ID. - pub(crate) fn next_trait_impl_id(&self) -> TraitImplId { + pub fn next_trait_impl_id(&self) -> TraitImplId { TraitImplId(self.trait_implementations.len()) } + + /// Removes all TraitImplKind::Assumed from the list of known impls for the given trait + pub fn remove_assumed_trait_implementations_for_trait(&mut self, trait_id: TraitId) { + let entries = self.trait_implementation_map.entry(trait_id).or_default(); + entries.retain(|(_, kind)| matches!(kind, TraitImplKind::Normal(_))); + } + + /// Tags the given identifier with the selected trait_impl so that monomorphization + /// can later recover which impl was selected, or alternatively see if it needs to + /// decide which impl to select (because the impl was Assumed). + pub fn select_impl_for_ident(&mut self, ident_id: ExprId, trait_impl: TraitImplKind) { + self.selected_trait_implementations.insert(ident_id, trait_impl); + } + + /// Tags the given identifier with the selected trait_impl so that monomorphization + /// can later recover which impl was selected, or alternatively see if it needs to + /// decide which (because the impl was Assumed). + pub fn get_selected_impl_for_ident(&self, ident_id: ExprId) -> Option { + self.selected_trait_implementations.get(&ident_id).cloned() + } } impl Methods { diff --git a/tooling/nargo_cli/tests/compile_success_empty/impl_with_where_clause/src/main.nr b/tooling/nargo_cli/tests/compile_success_empty/impl_with_where_clause/src/main.nr index 0a4bd65eb0c..7f7080dff43 100644 --- a/tooling/nargo_cli/tests/compile_success_empty/impl_with_where_clause/src/main.nr +++ b/tooling/nargo_cli/tests/compile_success_empty/impl_with_where_clause/src/main.nr @@ -1,12 +1,11 @@ fn main() { - // Test is temporarily disabled, see #3409 - // let array: [Field; 3] = [1, 2, 3]; - // assert(array.eq(array)); + let array: [Field; 3] = [1, 2, 3]; + assert(array.eq(array)); - // // Ensure this still works if we have to infer the type of the integer literals - // let array = [1, 2, 3]; - // assert(array.eq(array)); + // Ensure this still works if we have to infer the type of the integer literals + let array = [1, 2, 3]; + assert(array.eq(array)); } trait Eq { diff --git a/tooling/nargo_cli/tests/compile_success_empty/trait_self/trait_self/Nargo.toml b/tooling/nargo_cli/tests/compile_success_empty/trait_self/trait_self/Nargo.toml deleted file mode 100644 index 71c541ccd4f..00000000000 --- a/tooling/nargo_cli/tests/compile_success_empty/trait_self/trait_self/Nargo.toml +++ /dev/null @@ -1,6 +0,0 @@ -[package] -name = "trait_self" -type = "bin" -authors = [""] - -[dependencies] \ No newline at end of file diff --git a/tooling/nargo_cli/tests/compile_success_empty/trait_self/trait_self/src/main.nr b/tooling/nargo_cli/tests/compile_success_empty/trait_self/trait_self/src/main.nr deleted file mode 100644 index f4f73822cc3..00000000000 --- a/tooling/nargo_cli/tests/compile_success_empty/trait_self/trait_self/src/main.nr +++ /dev/null @@ -1,18 +0,0 @@ -struct Foo { - x: Field -} - -trait Asd { - fn asd() -> Self; -} - -impl Asd for Foo { - // the Self should typecheck properly - fn asd() -> Self { - Foo{x: 100} - } -} - -fn main() { - assert(Foo::asd().x == 100); -} \ No newline at end of file diff --git a/tooling/nargo_cli/tests/compile_success_empty/trait_self/Nargo.toml b/tooling/nargo_cli/tests/compile_success_empty/trait_static_methods/Nargo.toml similarity index 100% rename from tooling/nargo_cli/tests/compile_success_empty/trait_self/Nargo.toml rename to tooling/nargo_cli/tests/compile_success_empty/trait_static_methods/Nargo.toml diff --git a/tooling/nargo_cli/tests/compile_success_empty/trait_self/src/main.nr b/tooling/nargo_cli/tests/compile_success_empty/trait_static_methods/src/main.nr similarity index 100% rename from tooling/nargo_cli/tests/compile_success_empty/trait_self/src/main.nr rename to tooling/nargo_cli/tests/compile_success_empty/trait_static_methods/src/main.nr diff --git a/tooling/nargo_cli/tests/compile_success_empty/trait_where_clause/src/main.nr b/tooling/nargo_cli/tests/compile_success_empty/trait_where_clause/src/main.nr index 001bd5a6ec6..1ec736c17e3 100644 --- a/tooling/nargo_cli/tests/compile_success_empty/trait_where_clause/src/main.nr +++ b/tooling/nargo_cli/tests/compile_success_empty/trait_where_clause/src/main.nr @@ -25,7 +25,7 @@ impl Asd for AddXY { struct Static100 {} impl StaticTrait for Static100 { - // use default implementatino for static_function, which returns 100 + // use default implementation for static_function, which returns 100 } struct Static200 {} @@ -47,12 +47,11 @@ fn main() { let a = Add30{ x: 70 }; let xy = AddXY{ x: 30, y: 70 }; - // Temporarily disabled, see #3409 - // assert_asd_eq_100(x); - // assert_asd_eq_100(z); - // assert_asd_eq_100(a); - // assert_asd_eq_100(xy); + assert_asd_eq_100(x); + assert_asd_eq_100(z); + assert_asd_eq_100(a); + assert_asd_eq_100(xy); - // assert(add_one_to_static_function(Static100{}) == 101); - // assert(add_one_to_static_function(Static200{}) == 201); + assert(add_one_to_static_function(Static100{}) == 101); + assert(add_one_to_static_function(Static200{}) == 201); }