Skip to content

Commit 4d5fb4f

Browse files
committed
Merge pull request #4054 from weiznich/fix/auto_type_lifetimes
Fix lifetimes with `#[auto_type]`
1 parent 1062c27 commit 4d5fb4f

File tree

8 files changed

+228
-4
lines changed

8 files changed

+228
-4
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
use diesel::dsl::*;
2+
use diesel::prelude::*;
3+
4+
diesel::table! {
5+
users {
6+
id -> Integer,
7+
name -> Text,
8+
}
9+
}
10+
11+
#[auto_type]
12+
fn with_lifetime(name: &'_ str) -> _ {
13+
users::table.filter(users::name.eq(name))
14+
}
15+
16+
#[auto_type]
17+
fn with_lifetime2(name: &str) -> _ {
18+
users::table.filter(users::name.eq(name))
19+
}
20+
21+
fn main() {
22+
println!("Hello, world!");
23+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
error: `#[auto_type]` requires named lifetimes
2+
--> tests/fail/auto_type_life_times.rs:12:25
3+
|
4+
12 | fn with_lifetime(name: &'_ str) -> _ {
5+
| ^^
6+
7+
error: `#[auto_type]` requires named lifetimes
8+
--> tests/fail/auto_type_life_times.rs:17:25
9+
|
10+
17 | fn with_lifetime2(name: &str) -> _ {
11+
| ^^^^
12+
13+
error[E0106]: missing lifetime specifier
14+
--> tests/fail/auto_type_life_times.rs:12:25
15+
|
16+
12 | fn with_lifetime(name: &'_ str) -> _ {
17+
| ^^ expected named lifetime parameter
18+
|
19+
help: consider introducing a named lifetime parameter
20+
|
21+
12 | fn with_lifetime<'a>(name: &'a str) -> _ {
22+
| ++++ ~~
23+
24+
error[E0106]: missing lifetime specifier
25+
--> tests/fail/auto_type_life_times.rs:17:25
26+
|
27+
17 | fn with_lifetime2(name: &str) -> _ {
28+
| ^ expected named lifetime parameter
29+
|
30+
help: consider introducing a named lifetime parameter
31+
|
32+
17 | fn with_lifetime2<'a>(name: &'a str) -> _ {
33+
| ++++ ++
34+
35+
error: lifetime may not live long enough
36+
--> tests/fail/auto_type_life_times.rs:13:5
37+
|
38+
12 | fn with_lifetime(name: &'_ str) -> _ {
39+
| - let's call the lifetime of this reference `'1`
40+
13 | users::table.filter(users::name.eq(name))
41+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ returning this value requires that `'1` must outlive `'static`
42+
43+
error: lifetime may not live long enough
44+
--> tests/fail/auto_type_life_times.rs:18:5
45+
|
46+
17 | fn with_lifetime2(name: &str) -> _ {
47+
| - let's call the lifetime of this reference `'1`
48+
18 | users::table.filter(users::name.eq(name))
49+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ returning this value requires that `'1` must outlive `'static`

diesel_derives/tests/auto_type.rs

+18
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,24 @@ fn test_normal_functions() -> _ {
355355
))
356356
}
357357

358+
#[auto_type]
359+
fn with_lifetime<'a>(name: &'a str) -> _ {
360+
users::table.filter(users::name.eq(name))
361+
}
362+
363+
#[auto_type]
364+
fn with_type_generics<'a, T>(name: &'a T) -> _
365+
where
366+
&'a T: diesel::expression::AsExpression<diesel::sql_types::Text>,
367+
{
368+
users::name.eq(name)
369+
}
370+
371+
#[auto_type]
372+
fn with_const_generics<const N: i32>() -> _ {
373+
users::id.eq(N)
374+
}
375+
358376
// #[auto_type]
359377
// fn test_sql_fragment() -> _ {
360378
// sql("foo")

dsl_auto_type/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ either = "1"
1515
heck = "0.5"
1616
proc-macro2 = "1"
1717
quote = "1"
18-
syn = { version = "2", features = ["extra-traits", "full", "derive", "parsing"] }
18+
syn = { version = "2", features = ["extra-traits", "full", "derive", "parsing", "visit"] }
1919

2020
[dev-dependencies]
2121
diesel = { path = "../diesel" }

dsl_auto_type/src/auto_type/expression_type_inference.rs

+2
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,12 @@ impl TypeInferrer<'_> {
7373
Err(e) => self.register_error(e, expr.span()),
7474
}
7575
}
76+
7677
fn register_error(&self, error: syn::Error, infer_type_span: Span) -> syn::Type {
7778
self.errors.borrow_mut().push(Rc::new(error));
7879
parse_quote_spanned!(infer_type_span=> _)
7980
}
81+
8082
fn try_infer_expression_type(
8183
&self,
8284
expr: &syn::Expr,

dsl_auto_type/src/auto_type/local_variables_map.rs

+10
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,16 @@ impl<'a, 'p> LocalVariablesMap<'a, 'p> {
116116
Ok(())
117117
}
118118

119+
pub(crate) fn process_const_generic(&mut self, const_generic: &'a syn::ConstParam) {
120+
self.inner.map.insert(
121+
&const_generic.ident,
122+
LetStatementInferredType {
123+
type_: const_generic.ty.clone(),
124+
errors: Vec::new(),
125+
},
126+
);
127+
}
128+
119129
/// Finishes a block inference for this map.
120130
/// It may be initialized with `pat`s before (such as function parameters),
121131
/// then this function is used to infer the type of the last expression in the block.

dsl_auto_type/src/auto_type/mod.rs

+16-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
mod case;
22
pub mod expression_type_inference;
33
mod local_variables_map;
4+
mod referenced_generics;
45
mod settings_builder;
56

67
use {
@@ -134,6 +135,9 @@ pub(crate) fn auto_type_impl(
134135
parent: None,
135136
},
136137
};
138+
for const_generic in input_function.sig.generics.const_params() {
139+
local_variables_map.process_const_generic(const_generic);
140+
}
137141
for function_param in &input_function.sig.inputs {
138142
if let syn::FnArg::Typed(syn::PatType { pat, ty, .. }) = function_param {
139143
match local_variables_map.process_pat(pat, Some(ty), None) {
@@ -165,11 +169,19 @@ pub(crate) fn auto_type_impl(
165169

166170
let type_alias = match type_alias {
167171
Some(type_alias) => {
172+
// We're generating a type alias so we need to extract the necessary lifetimes and
173+
// generic type parameters for that type alias
174+
let type_alias_generics = referenced_generics::extract_referenced_generics(
175+
&return_type,
176+
&input_function.sig.generics,
177+
&mut errors,
178+
);
179+
168180
let vis = &input_function.vis;
169-
input_function.sig.output = parse_quote!(-> #type_alias);
181+
input_function.sig.output = parse_quote!(-> #type_alias #type_alias_generics);
170182
quote! {
171183
#[allow(non_camel_case_types)]
172-
#vis type #type_alias = #return_type;
184+
#vis type #type_alias #type_alias_generics = #return_type;
173185
}
174186
}
175187
None => {
@@ -180,12 +192,13 @@ pub(crate) fn auto_type_impl(
180192

181193
let mut res = quote! {
182194
#type_alias
195+
#[allow(clippy::needless_lifetimes)]
183196
#input_function
184197
};
185198

186199
for error in errors {
187200
// Extracting from the `Rc` only if it's the last reference is an elegant way to
188-
// deduplicate errors For this to work it is necessary that the rest of
201+
// deduplicate errors. For this to work it is necessary that the rest of
189202
// the errors (those from the local variables map that weren't used) are
190203
// dropped before, which is the case here, and that we are iterating on the
191204
// errors in an owned manner.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
use std::rc::Rc;
2+
use syn::parse_quote;
3+
use syn::visit::{self, Visit};
4+
use syn::{Ident, Lifetime};
5+
6+
pub(crate) fn extract_referenced_generics(
7+
ty: &syn::Type,
8+
generics: &syn::Generics,
9+
errors: &mut Vec<Rc<syn::Error>>,
10+
) -> syn::Generics {
11+
struct Visitor<'g, 'errs> {
12+
lifetimes: Vec<(&'g Lifetime, bool)>,
13+
type_parameters: Vec<(&'g Ident, bool)>,
14+
errors: &'errs mut Vec<Rc<syn::Error>>,
15+
}
16+
17+
let mut visitor = Visitor {
18+
lifetimes: generics
19+
.lifetimes()
20+
.map(|lt| (&lt.lifetime, false))
21+
.collect(),
22+
type_parameters: generics
23+
.type_params()
24+
.map(|tp| (&tp.ident, false))
25+
.collect(),
26+
errors,
27+
};
28+
visitor.lifetimes.sort_unstable();
29+
visitor.type_parameters.sort_unstable();
30+
31+
impl<'ast> Visit<'ast> for Visitor<'_, '_> {
32+
fn visit_lifetime(&mut self, lifetime: &'ast Lifetime) {
33+
if lifetime.ident == "_" {
34+
self.errors.push(Rc::new(syn::Error::new_spanned(
35+
lifetime,
36+
"`#[auto_type]` requires named lifetimes",
37+
)));
38+
} else if lifetime.ident != "static" {
39+
if let Ok(lifetime_idx) = self
40+
.lifetimes
41+
.binary_search_by_key(&lifetime, |(lt, _)| *lt)
42+
{
43+
self.lifetimes[lifetime_idx].1 = true;
44+
}
45+
}
46+
visit::visit_lifetime(self, lifetime)
47+
}
48+
49+
fn visit_type_reference(&mut self, reference: &'ast syn::TypeReference) {
50+
if reference.lifetime.is_none() {
51+
self.errors.push(Rc::new(syn::Error::new_spanned(
52+
reference,
53+
"`#[auto_type]` requires named lifetimes",
54+
)));
55+
}
56+
visit::visit_type_reference(self, reference)
57+
}
58+
59+
fn visit_type_path(&mut self, type_path: &'ast syn::TypePath) {
60+
if let Some(path_ident) = type_path.path.get_ident() {
61+
if let Ok(type_param_idx) = self
62+
.type_parameters
63+
.binary_search_by_key(&path_ident, |tp| tp.0)
64+
{
65+
self.type_parameters[type_param_idx].1 = true;
66+
}
67+
}
68+
visit::visit_type_path(self, type_path)
69+
}
70+
}
71+
72+
visitor.visit_type(ty);
73+
74+
let generic_params: syn::punctuated::Punctuated<syn::GenericParam, _> = generics
75+
.params
76+
.iter()
77+
.filter_map(|param| match param {
78+
syn::GenericParam::Lifetime(lt)
79+
if visitor
80+
.lifetimes
81+
.binary_search(&(&lt.lifetime, true))
82+
.is_ok() =>
83+
{
84+
let lt = &lt.lifetime;
85+
Some(parse_quote!(#lt))
86+
}
87+
syn::GenericParam::Type(tp)
88+
if visitor
89+
.type_parameters
90+
.binary_search(&(&tp.ident, true))
91+
.is_ok() =>
92+
{
93+
let ident = &tp.ident;
94+
Some(parse_quote!(#ident))
95+
}
96+
_ => None::<syn::GenericParam>,
97+
})
98+
.collect();
99+
100+
// We need to not set the lt_token and gt_token if `params` is empty to get
101+
// a reasonable error message for the case that there is no lifetime specifier
102+
// but we need one
103+
syn::Generics {
104+
lt_token: (!generic_params.is_empty()).then(Default::default),
105+
gt_token: (!generic_params.is_empty()).then(Default::default),
106+
params: generic_params,
107+
where_clause: None,
108+
}
109+
}

0 commit comments

Comments
 (0)