diff --git a/pyo3-derive-backend/src/pyclass.rs b/pyo3-derive-backend/src/pyclass.rs index 5f4838bfaee..23697a76009 100644 --- a/pyo3-derive-backend/src/pyclass.rs +++ b/pyo3-derive-backend/src/pyclass.rs @@ -56,72 +56,65 @@ impl PyClassArgs { match expr { syn::Expr::Path(ref exp) if exp.path.segments.len() == 1 => self.add_path(exp), syn::Expr::Assign(ref assign) => self.add_assign(assign), - _ => Err(syn::Error::new_spanned(expr, "Could not parse arguments")), + _ => Err(syn::Error::new_spanned(expr, "Failed to parse arguments")), } } /// Match a single flag fn add_assign(&mut self, assign: &syn::ExprAssign) -> syn::Result<()> { - let key = match *assign.left { - syn::Expr::Path(ref exp) if exp.path.segments.len() == 1 => { + let syn::ExprAssign { left, right, .. } = assign; + let key = match &**left { + syn::Expr::Path(exp) if exp.path.segments.len() == 1 => { exp.path.segments.first().unwrap().ident.to_string() } _ => { - return Err(syn::Error::new_spanned(assign, "could not parse argument")); + return Err(syn::Error::new_spanned(assign, "Failed to parse arguments")); } }; + macro_rules! expected { + ($expected: literal) => { + expected!($expected, right) + }; + ($expected: literal, $span: ident) => { + return Err(syn::Error::new_spanned( + $span, + concat!("Expected ", $expected), + )); + }; + } + match key.as_str() { "freelist" => { // We allow arbitrary expressions here so you can e.g. use `8*64` - self.freelist = Some(*assign.right.clone()); + self.freelist = Some(syn::Expr::clone(right)); } - "name" => match *assign.right { - syn::Expr::Path(ref exp) if exp.path.segments.len() == 1 => { + "name" => match &**right { + syn::Expr::Path(exp) if exp.path.segments.len() == 1 => { self.name = Some(exp.clone().into()); } - _ => { - return Err(syn::Error::new_spanned( - *assign.right.clone(), - "Wrong 'name' format", - )); - } + _ => expected!("type name (e.g., Name)"), }, - "extends" => match *assign.right { - syn::Expr::Path(ref exp) => { + "extends" => match &**right { + syn::Expr::Path(exp) => { self.base = syn::TypePath { path: exp.path.clone(), qself: None, }; self.has_extends = true; } - _ => { - return Err(syn::Error::new_spanned( - *assign.right.clone(), - "Wrong format for extends", - )); - } + _ => expected!("type path (e.g., my_mod::BaseClass)"), }, - "module" => match *assign.right { + "module" => match &**right { syn::Expr::Lit(syn::ExprLit { - lit: syn::Lit::Str(ref lit), + lit: syn::Lit::Str(lit), .. }) => { self.module = Some(lit.clone()); } - _ => { - return Err(syn::Error::new_spanned( - *assign.right.clone(), - "Wrong format for module", - )); - } + _ => expected!(r#"string literal (e.g., "my_mod")"#), }, - _ => { - return Err(syn::Error::new_spanned( - *assign.left.clone(), - "Unsupported parameter", - )); - } + _ => expected!("one of freelist/name/extends/module", left), }; Ok(()) @@ -145,9 +138,9 @@ impl PyClassArgs { } _ => { return Err(syn::Error::new_spanned( - exp.path.clone(), - "Unsupported parameter", - )); + &exp.path, + "Expected one of gc/weakref/subclass/dict", + )) } }; diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index c0cf995ac23..1d6358dccfc 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -3,6 +3,7 @@ fn test_compile_errors() { let t = trybuild::TestCases::new(); t.compile_fail("tests/ui/invalid_macro_args.rs"); t.compile_fail("tests/ui/invalid_property_args.rs"); + t.compile_fail("tests/ui/invalid_pyclass_args.rs"); t.compile_fail("tests/ui/invalid_pymethod_names.rs"); t.compile_fail("tests/ui/missing_clone.rs"); t.compile_fail("tests/ui/reject_generics.rs"); diff --git a/tests/ui/invalid_pyclass_args.rs b/tests/ui/invalid_pyclass_args.rs new file mode 100644 index 00000000000..57aa3a09a91 --- /dev/null +++ b/tests/ui/invalid_pyclass_args.rs @@ -0,0 +1,18 @@ +use pyo3::prelude::*; + +#[pyclass(extend=pyo3::types::PyDict)] +struct TypoIntheKey {} + +#[pyclass(extends = "PyDict")] +struct InvalidExtends {} + +#[pyclass(name = m::MyClass)] +struct InvalidName {} + +#[pyclass(module = my_module)] +struct InvalidModule {} + +#[pyclass(weakrev)] +struct InvalidArg {} + +fn main() {} diff --git a/tests/ui/invalid_pyclass_args.stderr b/tests/ui/invalid_pyclass_args.stderr new file mode 100644 index 00000000000..72373cd6d3e --- /dev/null +++ b/tests/ui/invalid_pyclass_args.stderr @@ -0,0 +1,29 @@ +error: Expected one of freelist/name/extends/module + --> $DIR/invalid_pyclass_args.rs:3:11 + | +3 | #[pyclass(extend=pyo3::types::PyDict)] + | ^^^^^^ + +error: Expected type path (e.g., my_mod::BaseClass) + --> $DIR/invalid_pyclass_args.rs:6:21 + | +6 | #[pyclass(extends = "PyDict")] + | ^^^^^^^^ + +error: Expected type name (e.g., Name) + --> $DIR/invalid_pyclass_args.rs:9:18 + | +9 | #[pyclass(name = m::MyClass)] + | ^^^^^^^^^^ + +error: Expected string literal (e.g., "my_mod") + --> $DIR/invalid_pyclass_args.rs:12:20 + | +12 | #[pyclass(module = my_module)] + | ^^^^^^^^^ + +error: Expected one of gc/weakref/subclass/dict + --> $DIR/invalid_pyclass_args.rs:15:11 + | +15 | #[pyclass(weakrev)] + | ^^^^^^^