Skip to content

Commit 2653a27

Browse files
committed
Merge pull request #4115 from weiznich/fix/sqlite_row_iter_shouldn_t_panic_if_called_again_after_error
Fixed a potential panic in SQLite row iterators
1 parent e380c52 commit 2653a27

File tree

3 files changed

+194
-110
lines changed

3 files changed

+194
-110
lines changed

diesel/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ cfg-if = "1"
5454
dotenvy = "0.15"
5555
ipnetwork = ">=0.12.2, <0.21.0"
5656
quickcheck = "1.0.3"
57+
tempfile = "3.10.1"
5758

5859
[features]
5960
default = ["with-deprecated", "32-column-tables"]

diesel/src/sqlite/connection/row.rs

+186-103
Original file line numberDiff line numberDiff line change
@@ -200,126 +200,209 @@ impl<'stmt, 'query> Field<'stmt, Sqlite> for SqliteField<'stmt, 'query> {
200200
}
201201
}
202202

203-
#[test]
204-
fn fun_with_row_iters() {
205-
crate::table! {
206-
#[allow(unused_parens)]
207-
users(id) {
208-
id -> Integer,
209-
name -> Text,
203+
#[cfg(test)]
204+
mod tests {
205+
use super::*;
206+
207+
#[test]
208+
fn fun_with_row_iters() {
209+
crate::table! {
210+
#[allow(unused_parens)]
211+
users(id) {
212+
id -> Integer,
213+
name -> Text,
214+
}
210215
}
211-
}
212216

213-
use crate::connection::LoadConnection;
214-
use crate::deserialize::{FromSql, FromSqlRow};
215-
use crate::prelude::*;
216-
use crate::row::{Field, Row};
217-
use crate::sql_types;
217+
use crate::connection::LoadConnection;
218+
use crate::deserialize::{FromSql, FromSqlRow};
219+
use crate::prelude::*;
220+
use crate::row::{Field, Row};
221+
use crate::sql_types;
218222

219-
let conn = &mut crate::test_helpers::connection();
223+
let conn = &mut crate::test_helpers::connection();
220224

221-
crate::sql_query("CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);")
222-
.execute(conn)
223-
.unwrap();
225+
crate::sql_query("CREATE TABLE users(id INTEGER PRIMARY KEY, name TEXT NOT NULL);")
226+
.execute(conn)
227+
.unwrap();
224228

225-
crate::insert_into(users::table)
226-
.values(vec![
227-
(users::id.eq(1), users::name.eq("Sean")),
228-
(users::id.eq(2), users::name.eq("Tess")),
229-
])
230-
.execute(conn)
231-
.unwrap();
229+
crate::insert_into(users::table)
230+
.values(vec![
231+
(users::id.eq(1), users::name.eq("Sean")),
232+
(users::id.eq(2), users::name.eq("Tess")),
233+
])
234+
.execute(conn)
235+
.unwrap();
232236

233-
let query = users::table.select((users::id, users::name));
237+
let query = users::table.select((users::id, users::name));
234238

235-
let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))];
239+
let expected = vec![(1, String::from("Sean")), (2, String::from("Tess"))];
236240

237-
let row_iter = conn.load(query).unwrap();
238-
for (row, expected) in row_iter.zip(&expected) {
239-
let row = row.unwrap();
241+
let row_iter = conn.load(query).unwrap();
242+
for (row, expected) in row_iter.zip(&expected) {
243+
let row = row.unwrap();
240244

241-
let deserialized = <(i32, String) as FromSqlRow<
242-
(sql_types::Integer, sql_types::Text),
243-
_,
244-
>>::build_from_row(&row)
245-
.unwrap();
245+
let deserialized = <(i32, String) as FromSqlRow<
246+
(sql_types::Integer, sql_types::Text),
247+
_,
248+
>>::build_from_row(&row)
249+
.unwrap();
246250

247-
assert_eq!(&deserialized, expected);
248-
}
251+
assert_eq!(&deserialized, expected);
252+
}
249253

250-
{
251-
let collected_rows = conn.load(query).unwrap().collect::<Vec<_>>();
254+
{
255+
let collected_rows = conn.load(query).unwrap().collect::<Vec<_>>();
252256

253-
for (row, expected) in collected_rows.iter().zip(&expected) {
254-
let deserialized = row
255-
.as_ref()
256-
.map(|row| {
257-
<(i32, String) as FromSqlRow<
257+
for (row, expected) in collected_rows.iter().zip(&expected) {
258+
let deserialized = row
259+
.as_ref()
260+
.map(|row| {
261+
<(i32, String) as FromSqlRow<
258262
(sql_types::Integer, sql_types::Text),
259263
_,
260264
>>::build_from_row(row).unwrap()
261-
})
262-
.unwrap();
265+
})
266+
.unwrap();
263267

264-
assert_eq!(&deserialized, expected);
268+
assert_eq!(&deserialized, expected);
269+
}
265270
}
271+
272+
let mut row_iter = conn.load(query).unwrap();
273+
274+
let first_row = row_iter.next().unwrap().unwrap();
275+
let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
276+
let first_values = (first_fields.0.value(), first_fields.1.value());
277+
278+
assert!(row_iter.next().unwrap().is_err());
279+
std::mem::drop(first_values);
280+
assert!(row_iter.next().unwrap().is_err());
281+
std::mem::drop(first_fields);
282+
283+
let second_row = row_iter.next().unwrap().unwrap();
284+
let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap());
285+
let second_values = (second_fields.0.value(), second_fields.1.value());
286+
287+
assert!(row_iter.next().unwrap().is_err());
288+
std::mem::drop(second_values);
289+
assert!(row_iter.next().unwrap().is_err());
290+
std::mem::drop(second_fields);
291+
292+
assert!(row_iter.next().is_none());
293+
294+
let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
295+
let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap());
296+
297+
let first_values = (first_fields.0.value(), first_fields.1.value());
298+
let second_values = (second_fields.0.value(), second_fields.1.value());
299+
300+
assert_eq!(
301+
<i32 as FromSql<sql_types::Integer, Sqlite>>::from_nullable_sql(first_values.0)
302+
.unwrap(),
303+
expected[0].0
304+
);
305+
assert_eq!(
306+
<String as FromSql<sql_types::Text, Sqlite>>::from_nullable_sql(first_values.1)
307+
.unwrap(),
308+
expected[0].1
309+
);
310+
311+
assert_eq!(
312+
<i32 as FromSql<sql_types::Integer, Sqlite>>::from_nullable_sql(second_values.0)
313+
.unwrap(),
314+
expected[1].0
315+
);
316+
assert_eq!(
317+
<String as FromSql<sql_types::Text, Sqlite>>::from_nullable_sql(second_values.1)
318+
.unwrap(),
319+
expected[1].1
320+
);
321+
322+
let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
323+
let first_values = (first_fields.0.value(), first_fields.1.value());
324+
325+
assert_eq!(
326+
<i32 as FromSql<sql_types::Integer, Sqlite>>::from_nullable_sql(first_values.0)
327+
.unwrap(),
328+
expected[0].0
329+
);
330+
assert_eq!(
331+
<String as FromSql<sql_types::Text, Sqlite>>::from_nullable_sql(first_values.1)
332+
.unwrap(),
333+
expected[0].1
334+
);
266335
}
267336

268-
let mut row_iter = conn.load(query).unwrap();
269-
270-
let first_row = row_iter.next().unwrap().unwrap();
271-
let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
272-
let first_values = (first_fields.0.value(), first_fields.1.value());
273-
274-
assert!(row_iter.next().unwrap().is_err());
275-
std::mem::drop(first_values);
276-
assert!(row_iter.next().unwrap().is_err());
277-
std::mem::drop(first_fields);
278-
279-
let second_row = row_iter.next().unwrap().unwrap();
280-
let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap());
281-
let second_values = (second_fields.0.value(), second_fields.1.value());
282-
283-
assert!(row_iter.next().unwrap().is_err());
284-
std::mem::drop(second_values);
285-
assert!(row_iter.next().unwrap().is_err());
286-
std::mem::drop(second_fields);
287-
288-
assert!(row_iter.next().is_none());
289-
290-
let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
291-
let second_fields = (second_row.get(0).unwrap(), second_row.get(1).unwrap());
292-
293-
let first_values = (first_fields.0.value(), first_fields.1.value());
294-
let second_values = (second_fields.0.value(), second_fields.1.value());
295-
296-
assert_eq!(
297-
<i32 as FromSql<sql_types::Integer, Sqlite>>::from_nullable_sql(first_values.0).unwrap(),
298-
expected[0].0
299-
);
300-
assert_eq!(
301-
<String as FromSql<sql_types::Text, Sqlite>>::from_nullable_sql(first_values.1).unwrap(),
302-
expected[0].1
303-
);
304-
305-
assert_eq!(
306-
<i32 as FromSql<sql_types::Integer, Sqlite>>::from_nullable_sql(second_values.0).unwrap(),
307-
expected[1].0
308-
);
309-
assert_eq!(
310-
<String as FromSql<sql_types::Text, Sqlite>>::from_nullable_sql(second_values.1).unwrap(),
311-
expected[1].1
312-
);
313-
314-
let first_fields = (first_row.get(0).unwrap(), first_row.get(1).unwrap());
315-
let first_values = (first_fields.0.value(), first_fields.1.value());
316-
317-
assert_eq!(
318-
<i32 as FromSql<sql_types::Integer, Sqlite>>::from_nullable_sql(first_values.0).unwrap(),
319-
expected[0].0
320-
);
321-
assert_eq!(
322-
<String as FromSql<sql_types::Text, Sqlite>>::from_nullable_sql(first_values.1).unwrap(),
323-
expected[0].1
324-
);
337+
#[cfg(feature = "returning_clauses_for_sqlite_3_35")]
338+
crate::define_sql_function! {fn sleep(a: diesel::sql_types::Integer) -> diesel::sql_types::Integer}
339+
340+
#[test]
341+
#[cfg(feature = "returning_clauses_for_sqlite_3_35")]
342+
fn parallel_iter_with_error() {
343+
use crate::connection::Connection;
344+
use crate::connection::LoadConnection;
345+
use crate::connection::SimpleConnection;
346+
use crate::expression_methods::ExpressionMethods;
347+
use crate::SqliteConnection;
348+
use std::sync::{Arc, Barrier};
349+
use std::time::Duration;
350+
351+
let temp_dir = tempfile::tempdir().unwrap();
352+
let db_path = format!("{}/test.db", temp_dir.path().display());
353+
let mut conn1 = SqliteConnection::establish(&db_path).unwrap();
354+
let mut conn2 = SqliteConnection::establish(&db_path).unwrap();
355+
356+
crate::table! {
357+
users {
358+
id -> Integer,
359+
name -> Text,
360+
}
361+
}
362+
363+
conn1
364+
.batch_execute("CREATE TABLE users(id INTEGER NOT NULL PRIMARY KEY, name TEXT)")
365+
.unwrap();
366+
367+
let barrier = Arc::new(Barrier::new(2));
368+
let barrier2 = barrier.clone();
369+
370+
// we unblock the main thread from the sleep function
371+
sleep_utils::register_impl(&mut conn2, move |a: i32| {
372+
barrier.wait();
373+
std::thread::sleep(Duration::from_secs(a as u64));
374+
a
375+
})
376+
.unwrap();
377+
378+
// spawn a background thread that locks the database file
379+
let handle = std::thread::spawn(move || {
380+
use crate::query_dsl::RunQueryDsl;
381+
382+
conn2
383+
.immediate_transaction(|conn| diesel::select(sleep(1)).execute(conn))
384+
.unwrap();
385+
});
386+
barrier2.wait();
387+
388+
// execute some action that also requires a lock
389+
let mut iter = conn1
390+
.load(
391+
diesel::insert_into(users::table)
392+
.values((users::id.eq(1), users::name.eq("John")))
393+
.returning(users::id),
394+
)
395+
.unwrap();
396+
397+
// get the first iterator result, that should return the lock error
398+
let n = iter.next().unwrap();
399+
assert!(n.is_err());
400+
401+
// check that the iterator is now empty
402+
let n = iter.next();
403+
assert!(n.is_none());
404+
405+
// join the background thread
406+
handle.join().unwrap();
407+
}
325408
}

diesel/src/sqlite/connection/statement_iterator.rs

+7-7
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ impl<'stmt, 'query> Iterator for StatementIterator<'stmt, 'query> {
9292
fn next(&mut self) -> Option<Self::Item> {
9393
use PrivateStatementIterator::{NotStarted, Started};
9494
match &mut self.inner {
95-
NotStarted(ref mut stmt) if stmt.is_some() => {
95+
NotStarted(ref mut stmt @ Some(_)) => {
9696
let mut stmt = stmt
9797
.take()
9898
.expect("It must be there because we checked that above");
@@ -161,12 +161,12 @@ impl<'stmt, 'query> Iterator for StatementIterator<'stmt, 'query> {
161161
)
162162
}
163163
}
164-
NotStarted(_) => unreachable!(
165-
"You've reached an impossible internal state. \
166-
If you ever see this error message please open \
167-
an issue at https://github.com/diesel-rs/diesel \
168-
providing example code how to trigger this error."
169-
),
164+
NotStarted(_s) => {
165+
// we likely got an error while executing the other
166+
// `NotStarted` branch above. In this case we just want to stop
167+
// iterating here
168+
None
169+
}
170170
}
171171
}
172172
}

0 commit comments

Comments
 (0)