Skip to content

Commit 8d28cdb

Browse files
tyranronilslv
andcommitted
Co-authored-by: ilslv <ilya.solovyiov@gmail.com>
1 parent c650713 commit 8d28cdb

10 files changed

+297
-102
lines changed

integration_tests/juniper_tests/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ publish = false
77
[dependencies]
88
derive_more = "0.99"
99
futures = "0.3"
10+
itertools = "0.10"
1011
juniper = { path = "../../juniper" }
1112
juniper_subscriptions = { path = "../../juniper_subscriptions" }
1213

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
//! Checks that long looping chain of fragments doesn't cause a stack overflow.
2+
//!
3+
//! ```graphql
4+
//! # Fragment loop example
5+
//! query {
6+
//! ...a
7+
//! }
8+
//!
9+
//! fragment a on Query {
10+
//! ...b
11+
//! }
12+
//!
13+
//! fragment b on Query {
14+
//! ...a
15+
//! }
16+
//! ```
17+
18+
use std::iter;
19+
20+
use itertools::Itertools as _;
21+
use juniper::{graphql_object, EmptyMutation, EmptySubscription, Variables};
22+
23+
struct Query;
24+
25+
#[graphql_object]
26+
impl Query {
27+
fn dummy() -> bool {
28+
false
29+
}
30+
}
31+
32+
type Schema = juniper::RootNode<'static, Query, EmptyMutation, EmptySubscription>;
33+
34+
#[tokio::test]
35+
async fn test() {
36+
const PERM: &str = "abcefghijk";
37+
const CIRCLE_SIZE: usize = 7500;
38+
39+
let query = iter::once(format!("query {{ ...{PERM} }} "))
40+
.chain(
41+
PERM.chars()
42+
.permutations(PERM.len())
43+
.map(|vec| vec.into_iter().collect::<String>())
44+
.take(CIRCLE_SIZE)
45+
.collect::<Vec<_>>()
46+
.into_iter()
47+
.circular_tuple_windows::<(_, _)>()
48+
.map(|(cur, next)| format!("fragment {cur} on Query {{ ...{next} }} ")),
49+
)
50+
.collect::<String>();
51+
52+
let schema = Schema::new(Query, EmptyMutation::new(), EmptySubscription::new());
53+
let _ = juniper::execute(&query, None, &schema, &Variables::new(), &())
54+
.await
55+
.unwrap_err();
56+
}

integration_tests/juniper_tests/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ mod codegen;
77
#[cfg(test)]
88
mod custom_scalar;
99
#[cfg(test)]
10+
mod cve_2022_31173;
11+
#[cfg(test)]
1012
mod explicit_null;
1113
#[cfg(test)]
1214
mod infallible_as_field_error;

juniper/CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# master
22

3+
- Fix [CVE-2022-31173](https://github.com/graphql-rust/juniper/security/advisories/GHSA-4rx6-g5vg-5f3j).
34
- Fix incorrect error when explicit `null` provided for `null`able list input parameter. ([#1086](https://github.com/graphql-rust/juniper/pull/1086))
45

56
# [[0.15.9] 2022-02-02](https://github.com/graphql-rust/juniper/releases/tag/juniper-v0.15.9)

juniper/src/validation/rules/no_fragment_cycles.rs

+43-25
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,6 @@ use crate::{
77
value::ScalarValue,
88
};
99

10-
pub struct NoFragmentCycles<'a> {
11-
current_fragment: Option<&'a str>,
12-
spreads: HashMap<&'a str, Vec<Spanning<&'a str>>>,
13-
fragment_order: Vec<&'a str>,
14-
}
15-
16-
struct CycleDetector<'a> {
17-
visited: HashSet<&'a str>,
18-
spreads: &'a HashMap<&'a str, Vec<Spanning<&'a str>>>,
19-
path_indices: HashMap<&'a str, usize>,
20-
errors: Vec<RuleError>,
21-
}
22-
2310
pub fn factory<'a>() -> NoFragmentCycles<'a> {
2411
NoFragmentCycles {
2512
current_fragment: None,
@@ -28,6 +15,12 @@ pub fn factory<'a>() -> NoFragmentCycles<'a> {
2815
}
2916
}
3017

18+
pub struct NoFragmentCycles<'a> {
19+
current_fragment: Option<&'a str>,
20+
spreads: HashMap<&'a str, Vec<Spanning<&'a str>>>,
21+
fragment_order: Vec<&'a str>,
22+
}
23+
3124
impl<'a, S> Visitor<'a, S> for NoFragmentCycles<'a>
3225
where
3326
S: ScalarValue,
@@ -38,14 +31,12 @@ where
3831
let mut detector = CycleDetector {
3932
visited: HashSet::new(),
4033
spreads: &self.spreads,
41-
path_indices: HashMap::new(),
4234
errors: Vec::new(),
4335
};
4436

4537
for frag in &self.fragment_order {
4638
if !detector.visited.contains(frag) {
47-
let mut path = Vec::new();
48-
detector.detect_from(frag, &mut path);
39+
detector.detect_from(frag);
4940
}
5041
}
5142

@@ -91,19 +82,46 @@ where
9182
}
9283
}
9384

85+
type CycleDetectorState<'a> = (&'a str, Vec<&'a Spanning<&'a str>>, HashMap<&'a str, usize>);
86+
87+
struct CycleDetector<'a> {
88+
visited: HashSet<&'a str>,
89+
spreads: &'a HashMap<&'a str, Vec<Spanning<&'a str>>>,
90+
errors: Vec<RuleError>,
91+
}
92+
9493
impl<'a> CycleDetector<'a> {
95-
fn detect_from(&mut self, from: &'a str, path: &mut Vec<&'a Spanning<&'a str>>) {
94+
fn detect_from(&mut self, from: &'a str) {
95+
let mut to_visit = Vec::new();
96+
to_visit.push((from, Vec::new(), HashMap::new()));
97+
98+
while let Some((from, path, path_indices)) = to_visit.pop() {
99+
to_visit.extend(self.detect_from_inner(from, path, path_indices));
100+
}
101+
}
102+
103+
/// This function should be called only inside [`Self::detect_from()`], as
104+
/// it's a recursive function using heap instead of a stack. So, instead of
105+
/// the recursive call, we return a [`Vec`] that is visited inside
106+
/// [`Self::detect_from()`].
107+
fn detect_from_inner(
108+
&mut self,
109+
from: &'a str,
110+
path: Vec<&'a Spanning<&'a str>>,
111+
mut path_indices: HashMap<&'a str, usize>,
112+
) -> Vec<CycleDetectorState<'a>> {
96113
self.visited.insert(from);
97114

98115
if !self.spreads.contains_key(from) {
99-
return;
116+
return Vec::new();
100117
}
101118

102-
self.path_indices.insert(from, path.len());
119+
path_indices.insert(from, path.len());
103120

121+
let mut to_visit = Vec::new();
104122
for node in &self.spreads[from] {
105-
let name = &node.item;
106-
let index = self.path_indices.get(name).cloned();
123+
let name = node.item;
124+
let index = path_indices.get(name).cloned();
107125

108126
if let Some(index) = index {
109127
let err_pos = if index < path.len() {
@@ -114,14 +132,14 @@ impl<'a> CycleDetector<'a> {
114132

115133
self.errors
116134
.push(RuleError::new(&error_message(name), &[err_pos.start]));
117-
} else if !self.visited.contains(name) {
135+
} else {
136+
let mut path = path.clone();
118137
path.push(node);
119-
self.detect_from(name, path);
120-
path.pop();
138+
to_visit.push((name, path, path_indices.clone()));
121139
}
122140
}
123141

124-
self.path_indices.remove(from);
142+
to_visit
125143
}
126144
}
127145

juniper/src/validation/rules/no_undefined_variables.rs

+35-13
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,6 @@ pub enum Scope<'a> {
1212
Fragment(&'a str),
1313
}
1414

15-
pub struct NoUndefinedVariables<'a> {
16-
defined_variables: HashMap<Option<&'a str>, (SourcePosition, HashSet<&'a str>)>,
17-
used_variables: HashMap<Scope<'a>, Vec<Spanning<&'a str>>>,
18-
current_scope: Option<Scope<'a>>,
19-
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
20-
}
21-
2215
pub fn factory<'a>() -> NoUndefinedVariables<'a> {
2316
NoUndefinedVariables {
2417
defined_variables: HashMap::new(),
@@ -28,6 +21,13 @@ pub fn factory<'a>() -> NoUndefinedVariables<'a> {
2821
}
2922
}
3023

24+
pub struct NoUndefinedVariables<'a> {
25+
defined_variables: HashMap<Option<&'a str>, (SourcePosition, HashSet<&'a str>)>,
26+
used_variables: HashMap<Scope<'a>, Vec<Spanning<&'a str>>>,
27+
current_scope: Option<Scope<'a>>,
28+
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
29+
}
30+
3131
impl<'a> NoUndefinedVariables<'a> {
3232
fn find_undef_vars(
3333
&'a self,
@@ -36,8 +36,34 @@ impl<'a> NoUndefinedVariables<'a> {
3636
unused: &mut Vec<&'a Spanning<&'a str>>,
3737
visited: &mut HashSet<Scope<'a>>,
3838
) {
39+
let mut to_visit = Vec::new();
40+
if let Some(spreads) = self.find_undef_vars_inner(scope, defined, unused, visited) {
41+
to_visit.push(spreads);
42+
}
43+
while let Some(spreads) = to_visit.pop() {
44+
for spread in spreads {
45+
if let Some(spreads) =
46+
self.find_undef_vars_inner(&Scope::Fragment(spread), defined, unused, visited)
47+
{
48+
to_visit.push(spreads);
49+
}
50+
}
51+
}
52+
}
53+
54+
/// This function should be called only inside [`Self::find_undef_vars()`],
55+
/// as it's a recursive function using heap instead of a stack. So, instead
56+
/// of the recursive call, we return a [`Vec`] that is visited inside
57+
/// [`Self::find_undef_vars()`].
58+
fn find_undef_vars_inner(
59+
&'a self,
60+
scope: &Scope<'a>,
61+
defined: &HashSet<&'a str>,
62+
unused: &mut Vec<&'a Spanning<&'a str>>,
63+
visited: &mut HashSet<Scope<'a>>,
64+
) -> Option<&'a Vec<&'a str>> {
3965
if visited.contains(scope) {
40-
return;
66+
return None;
4167
}
4268

4369
visited.insert(scope.clone());
@@ -50,11 +76,7 @@ impl<'a> NoUndefinedVariables<'a> {
5076
}
5177
}
5278

53-
if let Some(spreads) = self.spreads.get(scope) {
54-
for spread in spreads {
55-
self.find_undef_vars(&Scope::Fragment(spread), defined, unused, visited);
56-
}
57-
}
79+
self.spreads.get(scope)
5880
}
5981
}
6082

juniper/src/validation/rules/no_unused_fragments.rs

+32-17
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,12 @@ use crate::{
77
value::ScalarValue,
88
};
99

10-
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
10+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
1111
pub enum Scope<'a> {
1212
Operation(Option<&'a str>),
1313
Fragment(&'a str),
1414
}
1515

16-
pub struct NoUnusedFragments<'a> {
17-
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
18-
defined_fragments: HashSet<Spanning<&'a str>>,
19-
current_scope: Option<Scope<'a>>,
20-
}
21-
2216
pub fn factory<'a>() -> NoUnusedFragments<'a> {
2317
NoUnusedFragments {
2418
spreads: HashMap::new(),
@@ -27,21 +21,42 @@ pub fn factory<'a>() -> NoUnusedFragments<'a> {
2721
}
2822
}
2923

24+
pub struct NoUnusedFragments<'a> {
25+
spreads: HashMap<Scope<'a>, Vec<&'a str>>,
26+
defined_fragments: HashSet<Spanning<&'a str>>,
27+
current_scope: Option<Scope<'a>>,
28+
}
29+
3030
impl<'a> NoUnusedFragments<'a> {
31-
fn find_reachable_fragments(&self, from: &Scope<'a>, result: &mut HashSet<&'a str>) {
32-
if let Scope::Fragment(name) = *from {
31+
fn find_reachable_fragments(&'a self, from: Scope<'a>, result: &mut HashSet<&'a str>) {
32+
let mut to_visit = Vec::new();
33+
to_visit.push(from);
34+
35+
while let Some(from) = to_visit.pop() {
36+
if let Some(next) = self.find_reachable_fragments_inner(from, result) {
37+
to_visit.extend(next.iter().map(|s| Scope::Fragment(s)));
38+
}
39+
}
40+
}
41+
42+
/// This function should be called only inside
43+
/// [`Self::find_reachable_fragments()`], as it's a recursive function using
44+
/// heap instead of a stack. So, instead of the recursive call, we return a
45+
/// [`Vec`] that is visited inside [`Self::find_reachable_fragments()`].
46+
fn find_reachable_fragments_inner(
47+
&'a self,
48+
from: Scope<'a>,
49+
result: &mut HashSet<&'a str>,
50+
) -> Option<&'a Vec<&'a str>> {
51+
if let Scope::Fragment(name) = from {
3352
if result.contains(name) {
34-
return;
53+
return None;
3554
} else {
3655
result.insert(name);
3756
}
3857
}
3958

40-
if let Some(spreads) = self.spreads.get(from) {
41-
for spread in spreads {
42-
self.find_reachable_fragments(&Scope::Fragment(spread), result)
43-
}
44-
}
59+
self.spreads.get(&from)
4560
}
4661
}
4762

@@ -59,7 +74,7 @@ where
5974
}) = *def
6075
{
6176
let op_name = name.as_ref().map(|s| s.item);
62-
self.find_reachable_fragments(&Scope::Operation(op_name), &mut reachable);
77+
self.find_reachable_fragments(Scope::Operation(op_name), &mut reachable);
6378
}
6479
}
6580

@@ -96,7 +111,7 @@ where
96111
) {
97112
if let Some(ref scope) = self.current_scope {
98113
self.spreads
99-
.entry(scope.clone())
114+
.entry(*scope)
100115
.or_insert_with(Vec::new)
101116
.push(spread.item.name.item);
102117
}

0 commit comments

Comments
 (0)