diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 1ea68e386..f1f871f25 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -102,6 +102,16 @@ functions: ${PREPARE_SHELL} .evergreen/run-tests-serde.sh + "run decimal128 tests": + - command: shell.exec + type: test + params: + shell: bash + working_dir: "src" + script: | + ${PREPARE_SHELL} + .evergreen/run-tests-decimal128.sh + "compile only": - command: shell.exec type: test @@ -164,6 +174,10 @@ tasks: commands: - func: "run serde tests" + - name: "test-decimal128" + commands: + - func: "run decimal128 tests" + - name: "compile-only" commands: - func: "compile only" @@ -180,9 +194,9 @@ axes: - id: "extra-rust-versions" values: - id: "min" - display_name: "1.43 (minimum supported version)" + display_name: "1.48 (minimum supported version)" variables: - RUST_VERSION: "1.43.1" + RUST_VERSION: "1.48.0" - id: "nightly" display_name: "nightly" variables: @@ -198,6 +212,7 @@ buildvariants: - name: "test" - name: "test-u2i" - name: "test-serde" + - name: "test-decimal128" - matrix_name: "compile only" matrix_spec: diff --git a/.evergreen/run-tests-decimal128.sh b/.evergreen/run-tests-decimal128.sh new file mode 100755 index 000000000..10a30b014 --- /dev/null +++ b/.evergreen/run-tests-decimal128.sh @@ -0,0 +1,6 @@ +#!/bin/sh + +set -o errexit + +. ~/.cargo/env +RUST_BACKTRACE=1 cargo test --features decimal128 diff --git a/Cargo.toml b/Cargo.toml index a359f48bf..25a6a02b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,6 +57,7 @@ decimal = { version = "2.1.0", default_features = false, optional = true } base64 = "0.13.0" lazy_static = "1.4.0" uuid = "0.8.1" +serde_bytes = "0.11.5" [dev-dependencies] assert_matches = "1.2" diff --git a/README.md b/README.md index dfbfba6d6..5acce2278 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,8 @@ This crate works with Cargo and can be found on bson = "2.0.0-beta.2" ``` +This crate requires Rust 1.48+. + ## Overview of BSON Format BSON, short for Binary JSON, is a binary-encoded serialization of JSON-like documents. diff --git a/serde-tests/Cargo.toml b/serde-tests/Cargo.toml index 8244be081..62be8bd2e 100644 --- a/serde-tests/Cargo.toml +++ b/serde-tests/Cargo.toml @@ -5,8 +5,9 @@ authors = ["Kevin Yeh "] edition = "2018" [dependencies] -bson = { path = ".." } +bson = { path = "..", features = ["decimal128"] } serde = { version = "1.0", features = ["derive"] } +pretty_assertions = "0.6.1" [lib] name = "serde_tests" diff --git a/serde-tests/test.rs b/serde-tests/test.rs index 938ffebc4..d8e2df01f 100644 --- a/serde-tests/test.rs +++ b/serde-tests/test.rs @@ -1,77 +1,107 @@ #![allow(clippy::cognitive_complexity)] #![allow(clippy::vec_init_then_push)] -use serde::{self, de::Unexpected, Deserialize, Serialize}; -use std::collections::{BTreeMap, HashSet}; - -use bson::{Bson, Deserializer, Serializer}; - -macro_rules! bson { - ([]) => {{ bson::Bson::Array(Vec::new()) }}; - - ([$($val:tt),*]) => {{ - let mut array = Vec::new(); - - $( - array.push(bson!($val)); - )* - - bson::Bson::Array(array) - }}; - - ([$val:expr]) => {{ - bson::Bson::Array(vec!(::std::convert::From::from($val))) - }}; - - ({ $($k:expr => $v:tt),* }) => {{ - bdoc! 
{ - $( - $k => $v - ),* - } - }}; +use pretty_assertions::assert_eq; +use serde::{ + self, + de::{DeserializeOwned, Unexpected}, + Deserialize, + Serialize, +}; + +use std::{ + borrow::Cow, + collections::{BTreeMap, HashSet}, +}; + +use bson::{ + doc, + oid::ObjectId, + spec::BinarySubtype, + Binary, + Bson, + DateTime, + Decimal128, + Deserializer, + Document, + JavaScriptCodeWithScope, + Regex, + Timestamp, +}; + +/// Verifies the following: +/// - round trip `expected_value` through `Document`: +/// - serializing the `expected_value` to a `Document` matches the `expected_doc` +/// - deserializing from the serialized document produces `expected_value` +/// - round trip through raw BSON: +/// - deserializing a `T` from the raw BSON version of `expected_doc` produces `expected_value` +/// - deserializing a `Document` from the raw BSON version of `expected_doc` produces +/// `expected_doc` +fn run_test<T>(expected_value: &T, expected_doc: &Document, description: &str) +where + T: Serialize + DeserializeOwned + PartialEq + std::fmt::Debug, +{ + let mut expected_bytes = Vec::new(); + expected_doc + .to_writer(&mut expected_bytes) + .expect(description); + + let serialized_doc = bson::to_document(&expected_value).expect(description); + assert_eq!(&serialized_doc, expected_doc, "{}", description); + assert_eq!( + expected_value, + &bson::from_document::<T>(serialized_doc).expect(description), + "{}", + description + ); - ($val:expr) => {{ - ::std::convert::From::from($val) - }}; + assert_eq!( + &bson::from_reader::<_, T>(expected_bytes.as_slice()).expect(description), + expected_value, + "{}", + description + ); + assert_eq!( + &bson::from_reader::<_, Document>(expected_bytes.as_slice()).expect(description), + expected_doc, + "{}", + description + ); } -macro_rules! bdoc { - () => {{ Bson::Document(bson::Document::new()) }}; - - ( $($key:expr => $val:tt),* ) => {{ - let mut document = bson::Document::new(); +/// Verifies the following: +/// - deserializing a `T` from `expected_doc` produces `expected_value` +/// - deserializing a `T` from the raw BSON version of `expected_doc` produces `expected_value` +/// - deserializing a `Document` from the raw BSON version of `expected_doc` produces `expected_doc` +fn run_deserialize_test<T>(expected_value: &T, expected_doc: &Document, description: &str) +where + T: DeserializeOwned + PartialEq + std::fmt::Debug, +{ + let mut expected_bytes = Vec::new(); + expected_doc + .to_writer(&mut expected_bytes) + .expect(description); - $( - document.insert::<_, Bson>($key.to_owned(), bson!($val)); - )* - - Bson::Document(document) - }}; -} - -macro_rules! t { - ($e:expr) => { - match $e { - Ok(t) => t, - Err(e) => panic!("Failed with {:?}", e), - } - }; + assert_eq!( + &bson::from_document::<T>(expected_doc.clone()).expect(description), + expected_value, + "{}", + description + ); + assert_eq!( + &bson::from_reader::<_, T>(expected_bytes.as_slice()).expect(description), + expected_value, + "{}", + description + ); + assert_eq!( + &bson::from_reader::<_, Document>(expected_bytes.as_slice()).expect(description), + expected_doc, + "{}", + description + ); } -macro_rules! serialize( ($t:expr) => ({ - let e = Serializer::new(); - match $t.serialize(e) { - Ok(b) => b, - Err(e) => panic!("Failed to serialize: {}", e), - } -}) ); - -macro_rules! 
deserialize( ($t:expr) => ({ - let d = Deserializer::new($t); - t!(Deserialize::deserialize(d)) -}) ); - #[test] fn smoke() { #[derive(Serialize, Deserialize, PartialEq, Debug)] @@ -80,8 +110,9 @@ fn smoke() { } let v = Foo { a: 2 }; - assert_eq!(serialize!(v), bdoc! {"a" => (2_i64)}); - assert_eq!(v, deserialize!(serialize!(v))); + let expected = doc! { "a": 2_i64 }; + + run_test(&v, &expected, "smoke"); } #[test] @@ -92,12 +123,12 @@ fn smoke_under() { } let v = Foo { a_b: 2 }; - assert_eq!(serialize!(v), bdoc! { "a_b" => (2_i64) }); - assert_eq!(v, deserialize!(serialize!(v))); + let doc = doc! { "a_b": 2_i64 }; + run_test(&v, &doc, "smoke under"); let mut m = BTreeMap::new(); m.insert("a_b".to_string(), 2_i64); - assert_eq!(v, deserialize!(serialize!(m))); + run_test(&m, &doc, "smoke under BTreeMap"); } #[test] @@ -118,16 +149,13 @@ fn nested() { a: "test".to_string(), }, }; - assert_eq!( - serialize!(v), - bdoc! { - "a" => (2_i64), - "b" => { - "a" => "test" - } + let doc = doc! { + "a": 2_i64, + "b": { + "a": "test" } - ); - assert_eq!(v, deserialize!(serialize!(v))); + }; + run_test(&v, &doc, "nested"); } #[test] @@ -151,12 +179,13 @@ fn application_deserialize_error() { let d_bad1 = Deserializer::new(Bson::String("not an isize".to_string())); let d_bad2 = Deserializer::new(Bson::Int64(11)); - assert_eq!(Range10(5), t!(Deserialize::deserialize(d_good))); + assert_eq!( + Range10(5), + Deserialize::deserialize(d_good).expect("deserialization should succeed") + ); - let err1: Result = Deserialize::deserialize(d_bad1); - assert!(err1.is_err()); - let err2: Result = Deserialize::deserialize(d_bad2); - assert!(err2.is_err()); + Range10::deserialize(d_bad1).expect_err("deserialization from string should fail"); + Range10::deserialize(d_bad2).expect_err("deserialization from 11 should fail"); } #[test] @@ -169,13 +198,10 @@ fn array() { let v = Foo { a: vec![1, 2, 3, 4], }; - assert_eq!( - serialize!(v), - bdoc! { - "a" => [1, 2, 3, 4] - } - ); - assert_eq!(v, deserialize!(serialize!(v))); + let doc = doc! { + "a": [1, 2, 3, 4], + }; + run_test(&v, &doc, "array"); } #[test] @@ -186,13 +212,10 @@ fn tuple() { } let v = Foo { a: (1, 2, 3, 4) }; - assert_eq!( - serialize!(v), - bdoc! { - "a" => [1, 2, 3, 4] - } - ); - assert_eq!(v, deserialize!(serialize!(v))); + let doc = doc! { + "a": [1, 2, 3, 4], + }; + run_test(&v, &doc, "tuple"); } #[test] @@ -221,23 +244,20 @@ fn inner_structs_with_options() { b: 1.0, }, }; - assert_eq!( - serialize!(v), - bdoc! { - "a" => { - "a" => (Bson::Null), - "b" => { - "a" => "foo", - "b" => (4.5) - } - }, - "b" => { - "a" => "bar", - "b" => (1.0) + let doc = doc! { + "a": { + "a": Bson::Null, + "b": { + "a": "foo", + "b": 4.5, } + }, + "b": { + "a": "bar", + "b": 1.0, } - ); - assert_eq!(v, deserialize!(serialize!(v))); + }; + run_test(&v, &doc, "inner_structs_with_options"); } #[test] @@ -267,22 +287,19 @@ fn inner_structs_with_skippable_options() { b: 1.0, }, }; - assert_eq!( - serialize!(v), - bdoc! { - "a" => { - "b" => { - "a" => "foo", - "b" => (4.5) - } - }, - "b" => { - "a" => "bar", - "b" => (1.0) + let doc = doc! { + "a" : { + "b": { + "a": "foo", + "b": 4.5 } + }, + "b": { + "a": "bar", + "b": 1.0 } - ); - assert_eq!(v, deserialize!(serialize!(v))); + }; + run_test(&v, &doc, "inner_structs_with_skippable_options"); } #[test] @@ -306,17 +323,14 @@ fn hashmap() { s }, }; - assert_eq!( - serialize!(v), - bdoc! { - "map" => { - "bar" => 4, - "foo" => 10 - }, - "set" => ["a"] - } - ); - assert_eq!(v, deserialize!(serialize!(v))); + let doc = doc! 
{ + "map": { + "bar": 4, + "foo": 10 + }, + "set": ["a"] + }; + run_test(&v, &doc, "hashmap"); } #[test] @@ -331,13 +345,10 @@ fn tuple_struct() { let v = Bar { whee: Foo(1, "foo".to_string(), 4.5), }; - assert_eq!( - serialize!(v), - bdoc! { - "whee" => [1, "foo", (4.5)] - } - ); - assert_eq!(v, deserialize!(serialize!(v))); + let doc = doc! { + "whee": [1, "foo", 4.5], + }; + run_test(&v, &doc, "tuple_struct"); } #[test] @@ -354,13 +365,10 @@ fn table_array() { let v = Foo { a: vec![Bar { a: 1 }, Bar { a: 2 }], }; - assert_eq!( - serialize!(v), - bdoc! { - "a" => [{"a" => 1}, {"a" => 2}] - } - ); - assert_eq!(v, deserialize!(serialize!(v))); + let doc = doc! { + "a": [{ "a": 1 }, { "a": 2 }] + }; + run_test(&v, &doc, "table_array"); } #[test] @@ -370,11 +378,18 @@ fn type_conversion() { bar: i32, } - let d = Deserializer::new(bdoc! { - "bar" => 1 - }); - let a: Result = Deserialize::deserialize(d); - assert_eq!(a.unwrap(), Foo { bar: 1 }); + let v = Foo { bar: 1 }; + let doc = doc! { + "bar": 1_i64 + }; + let deserialized: Foo = bson::from_document(doc.clone()).unwrap(); + assert_eq!(deserialized, v); + + let mut bytes = Vec::new(); + doc.to_writer(&mut bytes).unwrap(); + + let bson_deserialized: Foo = bson::from_reader(bytes.as_slice()).unwrap(); + assert_eq!(bson_deserialized, v); } #[test] @@ -384,10 +399,14 @@ fn missing_errors() { bar: i32, } - let d = Deserializer::new(bdoc! {}); - let a: Result = Deserialize::deserialize(d); + let doc = doc! {}; - assert!(a.is_err()); + bson::from_document::(doc.clone()).unwrap_err(); + + let mut bytes = Vec::new(); + doc.to_writer(&mut bytes).unwrap(); + + bson::from_reader::<_, Foo>(bytes.as_slice()).unwrap_err(); } #[test] @@ -413,51 +432,54 @@ fn parse_enum() { } let v = Foo { a: E::Empty }; - assert_eq!(serialize!(v), bdoc! { "a" => "Empty" }); - assert_eq!(v, deserialize!(serialize!(v))); + let doc = doc! { "a": "Empty" }; + run_test(&v, &doc, "parse_enum: Empty"); let v = Foo { a: E::Bar(10) }; - assert_eq!(serialize!(v), bdoc! { "a" => { "Bar" => 10 } }); - assert_eq!(v, deserialize!(serialize!(v))); + let doc = doc! { "a": { "Bar": 10 } }; + run_test(&v, &doc, "parse_enum: newtype variant int"); let v = Foo { a: E::Baz(10.2) }; - assert_eq!(serialize!(v), bdoc! { "a" => { "Baz" => 10.2 } }); - assert_eq!(v, deserialize!(serialize!(v))); + let doc = doc! { "a": { "Baz": 10.2 } }; + run_test(&v, &doc, "parse_enum: newtype variant double"); let v = Foo { a: E::Pair(12, 42) }; - assert_eq!(serialize!(v), bdoc! { "a" => { "Pair" => [ 12, 42] } }); - assert_eq!(v, deserialize!(serialize!(v))); + let doc = doc! { "a": { "Pair": [12, 42] } }; + run_test(&v, &doc, "parse_enum: tuple variant"); let v = Foo { a: E::Last(Foo2 { test: "test".to_string(), }), }; - assert_eq!( - serialize!(v), - bdoc! { "a" => { "Last" => { "test" => "test" } } } - ); - assert_eq!(v, deserialize!(serialize!(v))); + let doc = doc! { + "a": { "Last": { "test": "test" } } + }; + run_test(&v, &doc, "parse_enum: newtype variant struct"); let v = Foo { a: E::Vector(vec![12, 42]), }; - assert_eq!(serialize!(v), bdoc! { "a" => { "Vector" => [ 12, 42 ] } }); - assert_eq!(v, deserialize!(serialize!(v))); + let doc = doc! { + "a": { "Vector": [12, 42] } + }; + run_test(&v, &doc, "parse_enum: newtype variant vector"); let v = Foo { a: E::Named { a: 12 }, }; - assert_eq!(serialize!(v), bdoc! { "a" => { "Named" => { "a" => 12 } } }); - assert_eq!(v, deserialize!(serialize!(v))); + let doc = doc! 
{ + "a": { "Named": { "a": 12 } } + }; + run_test(&v, &doc, "parse_enum: struct variant"); + let v = Foo { a: E::MultiNamed { a: 12, b: 42 }, }; - assert_eq!( - serialize!(v), - bdoc! { "a" => { "MultiNamed" => { "a" => 12, "b" => 42 } } } - ); - assert_eq!(v, deserialize!(serialize!(v))); + let doc = doc! { + "a": { "MultiNamed": { "a": 12, "b": 42 } } + }; + run_test(&v, &doc, "parse_enum: struct variant multiple fields"); } #[test] @@ -468,12 +490,12 @@ fn unused_fields() { } let v = Foo { a: 2 }; - let d = Deserializer::new(bdoc! { - "a" => 2, - "b" => 5 - }); + let doc = doc! { + "a": 2, + "b": 5, + }; - assert_eq!(v, t!(Deserialize::deserialize(d))); + run_deserialize_test(&v, &doc, "unused_fields"); } #[test] @@ -488,14 +510,14 @@ fn unused_fields2() { } let v = Foo { a: Bar { a: 2 } }; - let d = Deserializer::new(bdoc! { - "a" => { - "a" => 2, - "b" => 5 + let doc = doc! { + "a": { + "a": 2, + "b": 5 } - }); + }; - assert_eq!(v, t!(Deserialize::deserialize(d))); + run_deserialize_test(&v, &doc, "unused_fields2"); } #[test] @@ -510,12 +532,12 @@ fn unused_fields3() { } let v = Foo { a: Bar { a: 2 } }; - let d = Deserializer::new(bdoc! { - "a" => { - "a" => 2 + let doc = doc! { + "a": { + "a": 2 } - }); - assert_eq!(v, t!(Deserialize::deserialize(d))); + }; + run_deserialize_test(&v, &doc, "unused_fields3"); } #[test] @@ -528,12 +550,12 @@ fn unused_fields4() { let mut map = BTreeMap::new(); map.insert("a".to_owned(), "foo".to_owned()); let v = Foo { a: map }; - let d = Deserializer::new(bdoc! { - "a" => { - "a" => "foo" + let doc = doc! { + "a": { + "a": "foo" } - }); - assert_eq!(v, t!(Deserialize::deserialize(d))); + }; + run_deserialize_test(&v, &doc, "unused_fields4"); } #[test] @@ -546,10 +568,10 @@ fn unused_fields5() { let v = Foo { a: vec!["a".to_string()], }; - let d = Deserializer::new(bdoc! { - "a" => ["a"] - }); - assert_eq!(v, t!(Deserialize::deserialize(d))); + let doc = doc! { + "a": ["a"] + }; + run_deserialize_test(&v, &doc, "unusued_fields5"); } #[test] @@ -560,10 +582,10 @@ fn unused_fields6() { } let v = Foo { a: Some(vec![]) }; - let d = Deserializer::new(bdoc! { - "a" => [] - }); - assert_eq!(v, t!(Deserialize::deserialize(d))); + let doc = doc! { + "a": [] + }; + run_deserialize_test(&v, &doc, "unused_fieds6"); } #[test] @@ -580,10 +602,10 @@ fn unused_fields7() { let v = Foo { a: vec![Bar { a: 1 }], }; - let d = Deserializer::new(bdoc! { - "a" => [{"a" => 1, "b" => 2}] - }); - assert_eq!(v, t!(Deserialize::deserialize(d))); + let doc = doc! { + "a": [{"a": 1, "b": 2}] + }; + run_deserialize_test(&v, &doc, "unused_fields7"); } #[test] @@ -594,15 +616,19 @@ fn unused_fields_deny() { a: i32, } - let d = Deserializer::new(bdoc! { - "a" => 1, - "b" => 2 - }); - Foo::deserialize(d).expect_err("extra fields should cause failure"); + let doc = doc! { + "a": 1, + "b": 2, + }; + bson::from_document::(doc.clone()).expect_err("extra fields should cause failure"); + + let mut bytes = Vec::new(); + doc.to_writer(&mut bytes).unwrap(); + bson::from_reader::<_, Foo>(bytes.as_slice()).expect_err("extra fields should cause failure"); } #[test] -fn empty_arrays() { +fn default_array() { #[derive(Serialize, Deserialize, PartialEq, Debug)] struct Foo { #[serde(default)] @@ -612,12 +638,12 @@ fn empty_arrays() { struct Bar; let v = Foo { a: vec![] }; - let d = Deserializer::new(bdoc! {}); - assert_eq!(v, t!(Deserialize::deserialize(d))); + let doc = doc! 
{}; + run_deserialize_test(&v, &doc, "default_array"); } #[test] -fn empty_arrays2() { +fn null_array() { #[derive(Serialize, Deserialize, PartialEq, Debug)] struct Foo { a: Option>, @@ -626,12 +652,202 @@ fn empty_arrays2() { struct Bar; let v = Foo { a: None }; - let d = Deserializer::new(bdoc! {}); - assert_eq!(v, t!(Deserialize::deserialize(d))); + let doc = doc! {}; + run_deserialize_test(&v, &doc, "null_array"); +} + +#[test] +fn empty_array() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Foo { + a: Option>, + } + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Bar; let v = Foo { a: Some(vec![]) }; - let d = Deserializer::new(bdoc! { - "a" => [] - }); - assert_eq!(v, t!(Deserialize::deserialize(d))); + let doc = doc! { + "a": [] + }; + run_deserialize_test(&v, &doc, "empty_array"); +} + +#[test] +fn all_types() { + #[derive(Debug, Deserialize, Serialize, PartialEq)] + struct Bar { + a: i32, + b: i32, + } + + #[derive(Debug, Deserialize, Serialize, PartialEq)] + struct Foo { + x: i32, + y: i64, + s: String, + array: Vec, + bson: Bson, + oid: ObjectId, + null: Option<()>, + subdoc: Document, + b: bool, + d: f64, + binary: Binary, + binary_old: Binary, + binary_other: Binary, + date: DateTime, + regex: Regex, + ts: Timestamp, + i: Bar, + undefined: Bson, + code: Bson, + code_w_scope: JavaScriptCodeWithScope, + decimal: Decimal128, + symbol: Bson, + min_key: Bson, + max_key: Bson, + } + + let binary = Binary { + bytes: vec![36, 36, 36], + subtype: BinarySubtype::Generic, + }; + let binary_old = Binary { + bytes: vec![36, 36, 36], + subtype: BinarySubtype::BinaryOld, + }; + let binary_other = Binary { + bytes: vec![36, 36, 36], + subtype: BinarySubtype::UserDefined(0x81), + }; + let date = DateTime::now(); + let regex = Regex { + pattern: "hello".to_string(), + options: "x".to_string(), + }; + let timestamp = Timestamp { + time: 123, + increment: 456, + }; + let code = Bson::JavaScriptCode("console.log(1)".to_string()); + let code_w_scope = JavaScriptCodeWithScope { + code: "console.log(a)".to_string(), + scope: doc! { "a": 1 }, + }; + let oid = ObjectId::new(); + let subdoc = doc! { "k": true, "b": { "hello": "world" } }; + + let doc = doc! { + "x": 1, + "y": 2_i64, + "s": "oke", + "array": [ true, "oke", { "12": 24 } ], + "bson": 1234.5, + "oid": oid, + "null": Bson::Null, + "subdoc": subdoc.clone(), + "b": true, + "d": 12.5, + "binary": binary.clone(), + "binary_old": binary_old.clone(), + "binary_other": binary_other.clone(), + "date": date, + "regex": regex.clone(), + "ts": timestamp, + "i": { "a": 300, "b": 12345 }, + "undefined": Bson::Undefined, + "code": code.clone(), + "code_w_scope": code_w_scope.clone(), + "decimal": Bson::Decimal128(Decimal128::from_i32(5)), + "symbol": Bson::Symbol("ok".to_string()), + "min_key": Bson::MinKey, + "max_key": Bson::MaxKey, + }; + + let v = Foo { + x: 1, + y: 2, + s: "oke".to_string(), + array: vec![ + Bson::Boolean(true), + Bson::String("oke".to_string()), + Bson::Document(doc! 
{ "12": 24 }), + ], + bson: Bson::Double(1234.5), + oid, + null: None, + subdoc, + b: true, + d: 12.5, + binary, + binary_old, + binary_other, + date, + regex, + ts: timestamp, + i: Bar { a: 300, b: 12345 }, + undefined: Bson::Undefined, + code, + code_w_scope, + decimal: Decimal128::from_i32(5), + symbol: Bson::Symbol("ok".to_string()), + min_key: Bson::MinKey, + max_key: Bson::MaxKey, + }; + + run_test(&v, &doc, "all types"); +} + +#[test] +fn borrowed() { + #[derive(Debug, Deserialize, PartialEq)] + struct Foo<'a> { + s: &'a str, + binary: &'a [u8], + doc: Inner<'a>, + #[serde(borrow)] + cow: Cow<'a, str>, + #[serde(borrow)] + array: Vec<&'a str>, + } + + #[derive(Debug, Deserialize, PartialEq)] + struct Inner<'a> { + string: &'a str, + } + + let binary = Binary { + bytes: vec![36, 36, 36], + subtype: BinarySubtype::Generic, + }; + + let doc = doc! { + "s": "borrowed string", + "binary": binary.clone(), + "doc": { + "string": "another borrowed string", + }, + "cow": "cow", + "array": ["borrowed string"], + }; + let mut bson = Vec::new(); + doc.to_writer(&mut bson).unwrap(); + + let s = "borrowed string".to_string(); + let ss = "another borrowed string".to_string(); + let cow = "cow".to_string(); + let inner = Inner { + string: ss.as_str(), + }; + let v = Foo { + s: s.as_str(), + binary: binary.bytes.as_slice(), + doc: inner, + cow: Cow::Borrowed(cow.as_str()), + array: vec![s.as_str()], + }; + + let deserialized: Foo = + bson::from_slice(bson.as_slice()).expect("deserialization should succeed"); + assert_eq!(deserialized, v); } diff --git a/src/bson.rs b/src/bson.rs index 7de8a34db..03b8bee09 100644 --- a/src/bson.rs +++ b/src/bson.rs @@ -21,7 +21,10 @@ //! BSON definition -use std::fmt::{self, Debug, Display, Formatter}; +use std::{ + convert::TryFrom, + fmt::{self, Debug, Display, Formatter}, +}; use chrono::Datelike; use serde_json::{json, Value}; @@ -182,7 +185,7 @@ impl Debug for Bson { impl From for Bson { fn from(a: f32) -> Bson { - Bson::Double(a as f64) + Bson::Double(a.into()) } } @@ -297,7 +300,11 @@ impl From for Bson { impl From for Bson { fn from(a: u32) -> Bson { - Bson::Int32(a as i32) + if let Ok(i) = i32::try_from(a) { + Bson::Int32(i) + } else { + Bson::Int64(a.into()) + } } } @@ -529,8 +536,8 @@ impl Bson { /// This function mainly used for [extended JSON format](https://docs.mongodb.com/manual/reference/mongodb-extended-json/). // TODO RUST-426: Investigate either removing this from the serde implementation or unifying // with the extended JSON implementation. - pub(crate) fn to_extended_document(&self) -> Document { - match *self { + pub(crate) fn into_extended_document(self) -> Document { + match self { Bson::RegularExpression(Regex { ref pattern, ref options, @@ -542,23 +549,20 @@ impl Bson { doc! { "$regularExpression": { - "pattern": pattern.clone(), + "pattern": pattern, "options": options, } } } Bson::JavaScriptCode(ref code) => { doc! { - "$code": code.clone(), + "$code": code, } } - Bson::JavaScriptCodeWithScope(JavaScriptCodeWithScope { - ref code, - ref scope, - }) => { + Bson::JavaScriptCodeWithScope(JavaScriptCodeWithScope { code, scope }) => { doc! { - "$code": code.clone(), - "$scope": scope.clone(), + "$code": code, + "$scope": scope, } } Bson::Timestamp(Timestamp { time, increment }) => { @@ -583,7 +587,7 @@ impl Bson { "$oid": v.to_string(), } } - Bson::DateTime(v) if v.timestamp_millis() >= 0 && v.to_chrono().year() <= 99999 => { + Bson::DateTime(v) if v.timestamp_millis() >= 0 && v.to_chrono().year() <= 9999 => { doc! 
{ "$date": v.to_rfc3339(), } diff --git a/src/de/mod.rs b/src/de/mod.rs index aa1479afd..8a288f00f 100644 --- a/src/de/mod.rs +++ b/src/de/mod.rs @@ -22,6 +22,7 @@ //! Deserializer mod error; +mod raw; mod serde; pub use self::{ @@ -33,7 +34,8 @@ use std::io::Read; use crate::{ bson::{Array, Binary, Bson, DbPointer, Document, JavaScriptCodeWithScope, Regex, Timestamp}, - oid, + oid::{self, ObjectId}, + ser::write_i32, spec::{self, BinarySubtype}, Decimal128, }; @@ -43,10 +45,10 @@ use ::serde::{ Deserialize, }; -const MAX_BSON_SIZE: i32 = 16 * 1024 * 1024; +pub(crate) const MAX_BSON_SIZE: i32 = 16 * 1024 * 1024; pub(crate) const MIN_BSON_DOCUMENT_SIZE: i32 = 4 + 1; // 4 bytes for length, one byte for null terminator -const MIN_BSON_STRING_SIZE: i32 = 4 + 1; // 4 bytes for length, one byte for null terminator -const MIN_CODE_WITH_SCOPE_SIZE: i32 = 4 + MIN_BSON_STRING_SIZE + MIN_BSON_DOCUMENT_SIZE; +pub(crate) const MIN_BSON_STRING_SIZE: i32 = 4 + 1; // 4 bytes for length, one byte for null terminator +pub(crate) const MIN_CODE_WITH_SCOPE_SIZE: i32 = 4 + MIN_BSON_STRING_SIZE + MIN_BSON_DOCUMENT_SIZE; /// Run the provided closure, ensuring that over the course of its execution, exactly `length` bytes /// were read from the reader. @@ -72,7 +74,7 @@ where Ok(()) } -fn read_string(reader: &mut R, utf8_lossy: bool) -> Result { +pub(crate) fn read_string(reader: &mut R, utf8_lossy: bool) -> Result { let len = read_i32(reader)?; // UTF-8 String must have at least 1 byte (the last 0x00). @@ -104,6 +106,18 @@ fn read_string(reader: &mut R, utf8_lossy: bool) -> Result(mut reader: R) -> Result { + let val = read_u8(&mut reader)?; + if val > 1 { + return Err(Error::invalid_value( + Unexpected::Unsigned(val as u64), + &"boolean must be stored as 0 or 1", + )); + } + + Ok(val != 0) +} + fn read_cstring(reader: &mut R) -> Result { let mut v = Vec::new(); @@ -119,7 +133,7 @@ fn read_cstring(reader: &mut R) -> Result { } #[inline] -fn read_u8(reader: &mut R) -> Result { +pub(crate) fn read_u8(reader: &mut R) -> Result { let mut buf = [0; 1]; reader.read_exact(&mut buf)?; Ok(u8::from_le_bytes(buf)) @@ -133,7 +147,7 @@ pub(crate) fn read_i32(reader: &mut R) -> Result { } #[inline] -fn read_i64(reader: &mut R) -> Result { +pub(crate) fn read_i64(reader: &mut R) -> Result { let mut buf = [0; 8]; reader.read_exact(&mut buf)?; Ok(i64::from_le_bytes(buf)) @@ -214,43 +228,7 @@ pub(crate) fn deserialize_bson_kvp( Some(ElementType::String) => read_string(reader, utf8_lossy).map(Bson::String)?, Some(ElementType::EmbeddedDocument) => Document::from_reader(reader).map(Bson::Document)?, Some(ElementType::Array) => deserialize_array(reader, utf8_lossy).map(Bson::Array)?, - Some(ElementType::Binary) => { - let mut len = read_i32(reader)?; - if !(0..=MAX_BSON_SIZE).contains(&len) { - return Err(Error::invalid_length( - len as usize, - &format!("binary length must be between 0 and {}", MAX_BSON_SIZE).as_str(), - )); - } - let subtype = BinarySubtype::from(read_u8(reader)?); - - // Skip length data in old binary. 
- if let BinarySubtype::BinaryOld = subtype { - let data_len = read_i32(reader)?; - - if !(0..=(MAX_BSON_SIZE - 4)).contains(&data_len) { - return Err(Error::invalid_length( - data_len as usize, - &format!("0x02 length must be between 0 and {}", MAX_BSON_SIZE - 4) - .as_str(), - )); - } - - if data_len + 4 != len { - return Err(Error::invalid_length( - data_len as usize, - &"0x02 length did not match top level binary length", - )); - } - - len -= 4; - } - - let mut bytes = Vec::with_capacity(len as usize); - - reader.take(len as u64).read_to_end(&mut bytes)?; - Bson::Binary(Binary { subtype, bytes }) - } + Some(ElementType::Binary) => Bson::Binary(Binary::from_reader(reader)?), Some(ElementType::ObjectId) => { let mut objid = [0; 12]; for x in &mut objid { @@ -258,63 +236,20 @@ pub(crate) fn deserialize_bson_kvp( } Bson::ObjectId(oid::ObjectId::from_bytes(objid)) } - Some(ElementType::Boolean) => { - let val = read_u8(reader)?; - if val > 1 { - return Err(Error::invalid_value( - Unexpected::Unsigned(val as u64), - &"boolean must be stored as 0 or 1", - )); - } - - Bson::Boolean(val != 0) - } + Some(ElementType::Boolean) => Bson::Boolean(read_bool(reader)?), Some(ElementType::Null) => Bson::Null, Some(ElementType::RegularExpression) => { - let pattern = read_cstring(reader)?; - - let mut options: Vec<_> = read_cstring(reader)?.chars().collect(); - options.sort_unstable(); - - Bson::RegularExpression(Regex { - pattern, - options: options.into_iter().collect(), - }) + Bson::RegularExpression(Regex::from_reader(reader)?) } Some(ElementType::JavaScriptCode) => { read_string(reader, utf8_lossy).map(Bson::JavaScriptCode)? } Some(ElementType::JavaScriptCodeWithScope) => { - let length = read_i32(reader)?; - if length < MIN_CODE_WITH_SCOPE_SIZE { - return Err(Error::invalid_length( - length as usize, - &format!( - "code with scope length must be at least {}", - MIN_CODE_WITH_SCOPE_SIZE - ) - .as_str(), - )); - } else if length > MAX_BSON_SIZE { - return Err(Error::invalid_length( - length as usize, - &"code with scope length too large", - )); - } - - let mut buf = vec![0u8; (length - 4) as usize]; - reader.read_exact(&mut buf)?; - - let mut slice = buf.as_slice(); - let code = read_string(&mut slice, utf8_lossy)?; - let scope = Document::from_reader(&mut slice)?; - Bson::JavaScriptCodeWithScope(JavaScriptCodeWithScope { code, scope }) + Bson::JavaScriptCodeWithScope(JavaScriptCodeWithScope::from_reader(reader)?) } Some(ElementType::Int32) => read_i32(reader).map(Bson::Int32)?, Some(ElementType::Int64) => read_i64(reader).map(Bson::Int64)?, - Some(ElementType::Timestamp) => { - read_i64(reader).map(|val| Bson::Timestamp(Timestamp::from_le_i64(val)))? - } + Some(ElementType::Timestamp) => Bson::Timestamp(Timestamp::from_reader(reader)?), Some(ElementType::DateTime) => { // The int64 is UTC milliseconds since the Unix epoch. 
let time = read_i64(reader)?; @@ -323,15 +258,7 @@ pub(crate) fn deserialize_bson_kvp( Some(ElementType::Symbol) => read_string(reader, utf8_lossy).map(Bson::Symbol)?, Some(ElementType::Decimal128) => read_f128(reader).map(Bson::Decimal128)?, Some(ElementType::Undefined) => Bson::Undefined, - Some(ElementType::DbPointer) => { - let namespace = read_string(reader, utf8_lossy)?; - let mut objid = [0; 12]; - reader.read_exact(&mut objid)?; - Bson::DbPointer(DbPointer { - namespace, - id: oid::ObjectId::from_bytes(objid), - }) - } + Some(ElementType::DbPointer) => Bson::DbPointer(DbPointer::from_reader(reader)?), Some(ElementType::MaxKey) => Bson::MaxKey, Some(ElementType::MinKey) => Bson::MinKey, None => { @@ -345,6 +272,122 @@ pub(crate) fn deserialize_bson_kvp( Ok((key, val)) } +impl Binary { + pub(crate) fn from_reader(mut reader: R) -> Result { + let len = read_i32(&mut reader)?; + if !(0..=MAX_BSON_SIZE).contains(&len) { + return Err(Error::invalid_length( + len as usize, + &format!("binary length must be between 0 and {}", MAX_BSON_SIZE).as_str(), + )); + } + let subtype = BinarySubtype::from(read_u8(&mut reader)?); + Self::from_reader_with_len_and_payload(reader, len, subtype) + } + + pub(crate) fn from_reader_with_len_and_payload( + mut reader: R, + mut len: i32, + subtype: BinarySubtype, + ) -> Result { + if !(0..=MAX_BSON_SIZE).contains(&len) { + return Err(Error::invalid_length( + len as usize, + &format!("binary length must be between 0 and {}", MAX_BSON_SIZE).as_str(), + )); + } + + // Skip length data in old binary. + if let BinarySubtype::BinaryOld = subtype { + let data_len = read_i32(&mut reader)?; + + if !(0..=(MAX_BSON_SIZE - 4)).contains(&data_len) { + return Err(Error::invalid_length( + data_len as usize, + &format!("0x02 length must be between 0 and {}", MAX_BSON_SIZE - 4).as_str(), + )); + } + + if data_len + 4 != len { + return Err(Error::invalid_length( + data_len as usize, + &"0x02 length did not match top level binary length", + )); + } + + len -= 4; + } + + let mut bytes = Vec::with_capacity(len as usize); + + reader.take(len as u64).read_to_end(&mut bytes)?; + Ok(Binary { subtype, bytes }) + } +} + +impl DbPointer { + pub(crate) fn from_reader(mut reader: R) -> Result { + let ns = read_string(&mut reader, false)?; + let oid = ObjectId::from_reader(&mut reader)?; + Ok(DbPointer { + namespace: ns, + id: oid, + }) + } +} + +impl Regex { + pub(crate) fn from_reader(mut reader: R) -> Result { + let pattern = read_cstring(&mut reader)?; + let options = read_cstring(&mut reader)?; + + Ok(Regex { pattern, options }) + } +} + +impl Timestamp { + pub(crate) fn from_reader(mut reader: R) -> Result { + read_i64(&mut reader).map(Timestamp::from_le_i64) + } +} + +impl ObjectId { + pub(crate) fn from_reader(mut reader: R) -> Result { + let mut buf = [0u8; 12]; + reader.read_exact(&mut buf)?; + Ok(Self::from_bytes(buf)) + } +} + +impl JavaScriptCodeWithScope { + pub(crate) fn from_reader(mut reader: R) -> Result { + let length = read_i32(&mut reader)?; + if length < MIN_CODE_WITH_SCOPE_SIZE { + return Err(Error::invalid_length( + length as usize, + &format!( + "code with scope length must be at least {}", + MIN_CODE_WITH_SCOPE_SIZE + ) + .as_str(), + )); + } else if length > MAX_BSON_SIZE { + return Err(Error::invalid_length( + length as usize, + &"code with scope length too large", + )); + } + + let mut buf = vec![0u8; (length - 4) as usize]; + reader.read_exact(&mut buf)?; + + let mut slice = buf.as_slice(); + let code = read_string(&mut slice, false)?; + let scope = 
Document::from_reader(&mut slice)?; + Ok(JavaScriptCodeWithScope { code, scope }) + } +} + /// Decode a BSON `Value` into a `T` Deserializable. pub fn from_bson(bson: Bson) -> Result where @@ -361,3 +404,65 @@ where { from_bson(Bson::Document(doc)) } + +fn reader_to_vec(mut reader: R) -> Result> { + let length = read_i32(&mut reader)?; + + if length < MIN_BSON_DOCUMENT_SIZE { + return Err(Error::custom("document size too small")); + } + + let mut bytes = Vec::with_capacity(length as usize); + write_i32(&mut bytes, length).map_err(Error::custom)?; + + reader.take(length as u64 - 4).read_to_end(&mut bytes)?; + Ok(bytes) +} + +/// Deserialize an instance of type `T` from an I/O stream of BSON. +pub fn from_reader(reader: R) -> Result +where + T: DeserializeOwned, + R: Read, +{ + let bytes = reader_to_vec(reader)?; + from_slice(bytes.as_slice()) +} + +/// Deserialize an instance of type `T` from an I/O stream of BSON, replacing any invalid UTF-8 +/// sequences with the Unicode replacement character. +/// +/// This is mainly useful when reading raw BSON returned from a MongoDB server, which +/// in rare cases can contain invalidly truncated strings (https://jira.mongodb.org/browse/SERVER-24007). +/// For most use cases, [`crate::from_reader`] can be used instead. +pub fn from_reader_utf8_lossy(reader: R) -> Result +where + T: DeserializeOwned, + R: Read, +{ + let bytes = reader_to_vec(reader)?; + from_slice_utf8_lossy(bytes.as_slice()) +} + +/// Deserialize an instance of type `T` from a slice of BSON bytes. +pub fn from_slice<'de, T>(bytes: &'de [u8]) -> Result +where + T: Deserialize<'de>, +{ + let mut deserializer = raw::Deserializer::new(bytes, false); + T::deserialize(&mut deserializer) +} + +/// Deserialize an instance of type `T` from a slice of BSON bytes, replacing any invalid UTF-8 +/// sequences with the Unicode replacement character. +/// +/// This is mainly useful when reading raw BSON returned from a MongoDB server, which +/// in rare cases can contain invalidly truncated strings (https://jira.mongodb.org/browse/SERVER-24007). +/// For most use cases, [`crate::from_slice`] can be used instead. +pub fn from_slice_utf8_lossy<'de, T>(bytes: &'de [u8]) -> Result +where + T: Deserialize<'de>, +{ + let mut deserializer = raw::Deserializer::new(bytes, true); + T::deserialize(&mut deserializer) +} diff --git a/src/de/raw.rs b/src/de/raw.rs new file mode 100644 index 000000000..261b9e737 --- /dev/null +++ b/src/de/raw.rs @@ -0,0 +1,977 @@ +use std::{ + borrow::Cow, + io::{ErrorKind, Read}, +}; + +use serde::{ + de::{EnumAccess, Error as SerdeError, IntoDeserializer, VariantAccess}, + forward_to_deserialize_any, + Deserializer as SerdeDeserializer, +}; + +use crate::{ + oid::ObjectId, + spec::{BinarySubtype, ElementType}, + Binary, + Bson, + DateTime, + DbPointer, + Decimal128, + JavaScriptCodeWithScope, + Regex, + Timestamp, +}; + +use super::{ + read_bool, + read_f128, + read_f64, + read_i32, + read_i64, + read_string, + read_u8, + Error, + Result, + MAX_BSON_SIZE, +}; +use crate::de::serde::MapDeserializer; + +/// Deserializer used to parse and deserialize raw BSON bytes. +pub(crate) struct Deserializer<'de> { + bytes: BsonBuf<'de>, + + /// The type of the element currently being deserialized. + /// + /// When the Deserializer is initialized, this will be `ElementType::EmbeddedDocument`, as the + /// only top level type is a document. 
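// [Editor's illustration — not part of the diff.] A minimal sketch of how the new public entry
// points added above (`from_reader`/`from_slice` and their `_utf8_lossy` variants) might be
// used; the `Person` struct and its values are hypothetical, not taken from the PR:
fn raw_deserialization_sketch() {
    use serde::Deserialize;

    #[derive(Deserialize)]
    struct Person {
        name: String,
        age: i32,
    }

    // Produce raw BSON bytes from a document, as the PR's own tests do via `Document::to_writer`.
    let doc = bson::doc! { "name": "Alice", "age": 30_i32 };
    let mut bytes = Vec::new();
    doc.to_writer(&mut bytes).expect("writing a valid document should not fail");

    // The reader-based API accepts any `std::io::Read`...
    let a: Person = bson::from_reader(bytes.as_slice()).expect("valid BSON");
    // ...while the slice-based API can additionally borrow strings and byte slices zero-copy
    // (see the `borrowed` test earlier in this diff).
    let b: Person = bson::from_slice(&bytes).expect("valid BSON");
    assert_eq!(a.name, b.name);
    assert_eq!(a.age, b.age);
}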
The "embedded" portion is incorrect in this context, + /// but given that there's no difference between deserializing an embedded document and a + /// top level one, the distinction isn't necessary. + current_type: ElementType, +} + +impl<'de> Deserializer<'de> { + pub(crate) fn new(buf: &'de [u8], utf8_lossy: bool) -> Self { + Self { + bytes: BsonBuf::new(buf, utf8_lossy), + current_type: ElementType::EmbeddedDocument, + } + } + + /// Ensure the entire document was visited, returning an error if not. + /// Will read the trailing null byte if necessary (i.e. the visitor stopped after visiting + /// exactly the number of elements in the document). + fn end_document(&mut self, length_remaining: i32) -> Result<()> { + match length_remaining.cmp(&1) { + std::cmp::Ordering::Equal => { + let nullbyte = read_u8(&mut self.bytes)?; + if nullbyte != 0 { + return Err(Error::custom(format!( + "expected null byte at end of document, got {:#x} instead", + nullbyte + ))); + } + } + std::cmp::Ordering::Greater => { + return Err(Error::custom(format!( + "document has bytes remaining that were not visited: {}", + length_remaining + ))); + } + std::cmp::Ordering::Less => { + if length_remaining < 0 { + return Err(Error::custom("length of document was too short")); + } + } + } + Ok(()) + } + + /// Read a string from the BSON. + /// + /// If utf8_lossy, this will be an owned string if invalid UTF-8 is encountered in the string, + /// otherwise it will be borrowed. + fn deserialize_str(&mut self) -> Result> { + self.bytes.read_str() + } + + fn deserialize_document_key(&mut self) -> Result> { + self.bytes.read_cstr() + } + + /// Construct a `DocumentAccess` and pass it into the provided closure, returning the + /// result of the closure if no other errors are encountered. + fn deserialize_document(&mut self, f: F) -> Result + where + F: FnOnce(DocumentAccess<'_, 'de>) -> Result, + { + let mut length_remaining = read_i32(&mut self.bytes)? - 4; + let out = f(DocumentAccess { + root_deserializer: self, + length_remaining: &mut length_remaining, + }); + + if out.is_ok() { + self.end_document(length_remaining)?; + } + out + } + + /// Deserialize the next element type and update `current_type` accordingly. + /// Returns `None` if a null byte is read. + fn deserialize_next_type(&mut self) -> Result> { + let tag = read_u8(&mut self.bytes)?; + if tag == 0 { + return Ok(None); + } + + let element_type = ElementType::from(tag) + .ok_or_else(|| Error::custom(format!("invalid element type: {}", tag)))?; + + self.current_type = element_type; + Ok(Some(element_type)) + } +} + +impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut Deserializer<'de> { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self.current_type { + ElementType::Int32 => visitor.visit_i32(read_i32(&mut self.bytes)?), + ElementType::Int64 => visitor.visit_i64(read_i64(&mut self.bytes)?), + ElementType::Double => visitor.visit_f64(read_f64(&mut self.bytes)?), + ElementType::String => match self.deserialize_str()? 
{ + Cow::Borrowed(s) => visitor.visit_borrowed_str(s), + Cow::Owned(string) => visitor.visit_string(string), + }, + ElementType::Boolean => visitor.visit_bool(read_bool(&mut self.bytes)?), + ElementType::Null => visitor.visit_unit(), + ElementType::ObjectId => { + let oid = ObjectId::from_reader(&mut self.bytes)?; + visitor.visit_map(ObjectIdAccess::new(oid)) + } + ElementType::EmbeddedDocument => { + self.deserialize_document(|access| visitor.visit_map(access)) + } + ElementType::Array => self.deserialize_document(|access| visitor.visit_seq(access)), + ElementType::Binary => { + let len = read_i32(&mut self.bytes)?; + if !(0..=MAX_BSON_SIZE).contains(&len) { + return Err(Error::invalid_length( + len as usize, + &format!("binary length must be between 0 and {}", MAX_BSON_SIZE).as_str(), + )); + } + let subtype = BinarySubtype::from(read_u8(&mut self.bytes)?); + match subtype { + BinarySubtype::Generic => { + visitor.visit_borrowed_bytes(self.bytes.read_slice(len as usize)?) + } + _ => { + let binary = Binary::from_reader_with_len_and_payload( + &mut self.bytes, + len, + subtype, + )?; + let mut d = BinaryDeserializer::new(binary); + visitor.visit_map(BinaryAccess { + deserializer: &mut d, + }) + } + } + } + ElementType::Undefined => { + let doc = Bson::Undefined.into_extended_document(); + visitor.visit_map(MapDeserializer::new(doc)) + } + ElementType::DateTime => { + let dti = read_i64(&mut self.bytes)?; + let dt = DateTime::from_millis(dti); + let mut d = DateTimeDeserializer::new(dt); + visitor.visit_map(DateTimeAccess { + deserializer: &mut d, + }) + } + ElementType::RegularExpression => { + let doc = Bson::RegularExpression(Regex::from_reader(&mut self.bytes)?) + .into_extended_document(); + visitor.visit_map(MapDeserializer::new(doc)) + } + ElementType::DbPointer => { + let doc = Bson::DbPointer(DbPointer::from_reader(&mut self.bytes)?) 
+ .into_extended_document(); + visitor.visit_map(MapDeserializer::new(doc)) + } + ElementType::JavaScriptCode => { + let code = read_string(&mut self.bytes, false)?; + let doc = Bson::JavaScriptCode(code).into_extended_document(); + visitor.visit_map(MapDeserializer::new(doc)) + } + ElementType::JavaScriptCodeWithScope => { + let code_w_scope = JavaScriptCodeWithScope::from_reader(&mut self.bytes)?; + let doc = Bson::JavaScriptCodeWithScope(code_w_scope).into_extended_document(); + visitor.visit_map(MapDeserializer::new(doc)) + } + ElementType::Symbol => { + let symbol = read_string(&mut self.bytes, false)?; + let doc = Bson::Symbol(symbol).into_extended_document(); + visitor.visit_map(MapDeserializer::new(doc)) + } + ElementType::Timestamp => { + let ts = Timestamp::from_reader(&mut self.bytes)?; + let mut d = TimestampDeserializer::new(ts); + visitor.visit_map(TimestampAccess { + deserializer: &mut d, + }) + } + ElementType::Decimal128 => { + let d128 = read_f128(&mut self.bytes)?; + visitor.visit_map(Decimal128Access::new(d128)) + } + ElementType::MaxKey => { + let doc = Bson::MaxKey.into_extended_document(); + visitor.visit_map(MapDeserializer::new(doc)) + } + ElementType::MinKey => { + let doc = Bson::MinKey.into_extended_document(); + visitor.visit_map(MapDeserializer::new(doc)) + } + } + } + + #[inline] + fn deserialize_option(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self.current_type { + ElementType::Null => visitor.visit_none(), + _ => visitor.visit_some(self), + } + } + + fn deserialize_enum( + self, + _name: &str, + _variants: &'static [&'static str], + visitor: V, + ) -> Result + where + V: serde::de::Visitor<'de>, + { + match self.current_type { + ElementType::String => visitor.visit_enum(self.deserialize_str()?.into_deserializer()), + ElementType::EmbeddedDocument => { + self.deserialize_document(|access| visitor.visit_enum(access)) + } + t => Err(Error::custom(format!("expected enum, instead got {:?}", t))), + } + } + + fn is_human_readable(&self) -> bool { + false + } + + forward_to_deserialize_any! { + bool char str bytes byte_buf unit unit_struct string + identifier newtype_struct seq tuple tuple_struct struct + map ignored_any i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 + } +} + +/// Struct for accessing documents for deserialization purposes. +/// This is used to deserialize maps, structs, sequences, and enums. +struct DocumentAccess<'d, 'de> { + root_deserializer: &'d mut Deserializer<'de>, + length_remaining: &'d mut i32, +} + +impl<'d, 'de> DocumentAccess<'d, 'de> { + /// Read the next element type and update the root deserializer with it. + /// + /// Returns `Ok(None)` if the document has been fully read and has no more elements. + fn read_next_type(&mut self) -> Result> { + let t = self.read(|s| s.root_deserializer.deserialize_next_type())?; + + if t.is_none() && *self.length_remaining != 0 { + return Err(Error::custom(format!( + "got null byte but still have length {} remaining", + self.length_remaining + ))); + } + + Ok(t) + } + + /// Executes a closure that reads from the BSON bytes and returns an error if the number of + /// bytes read exceeds length_remaining. + /// + /// A mutable reference to this `DocumentAccess` is passed into the closure. 
+ fn read(&mut self, f: F) -> Result + where + F: FnOnce(&mut Self) -> Result, + { + let start_bytes = self.root_deserializer.bytes.bytes_read(); + let out = f(self); + let bytes_read = self.root_deserializer.bytes.bytes_read() - start_bytes; + *self.length_remaining -= bytes_read as i32; + + if *self.length_remaining < 0 { + return Err(Error::custom("length of document too short")); + } + out + } + + /// Read the next value from the document. + fn read_next_value(&mut self, seed: V) -> Result + where + V: serde::de::DeserializeSeed<'de>, + { + self.read(|s| seed.deserialize(&mut *s.root_deserializer)) + } +} + +impl<'d, 'de> serde::de::MapAccess<'de> for DocumentAccess<'d, 'de> { + type Error = crate::de::Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: serde::de::DeserializeSeed<'de>, + { + if self.read_next_type()?.is_none() { + return Ok(None); + } + + self.read(|s| { + seed.deserialize(DocumentKeyDeserializer { + root_deserializer: &mut *s.root_deserializer, + }) + }) + .map(Some) + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: serde::de::DeserializeSeed<'de>, + { + self.read_next_value(seed) + } +} + +impl<'d, 'de> serde::de::SeqAccess<'de> for DocumentAccess<'d, 'de> { + type Error = Error; + + fn next_element_seed(&mut self, seed: S) -> Result> + where + S: serde::de::DeserializeSeed<'de>, + { + if self.read_next_type()?.is_none() { + return Ok(None); + } + let _index = self.read(|s| s.root_deserializer.deserialize_document_key())?; + self.read_next_value(seed).map(Some) + } +} + +impl<'d, 'de> EnumAccess<'de> for DocumentAccess<'d, 'de> { + type Error = Error; + type Variant = Self; + + fn variant_seed(mut self, seed: V) -> Result<(V::Value, Self::Variant)> + where + V: serde::de::DeserializeSeed<'de>, + { + if self.read_next_type()?.is_none() { + return Err(Error::EndOfStream); + } + + let key = self.read(|s| { + seed.deserialize(DocumentKeyDeserializer { + root_deserializer: &mut *s.root_deserializer, + }) + })?; + + Ok((key, self)) + } +} + +impl<'d, 'de> VariantAccess<'de> for DocumentAccess<'d, 'de> { + type Error = Error; + + fn unit_variant(self) -> Result<()> { + Err(Error::custom( + "expected a string enum, got a document instead", + )) + } + + fn newtype_variant_seed(mut self, seed: S) -> Result + where + S: serde::de::DeserializeSeed<'de>, + { + self.read_next_value(seed) + } + + fn tuple_variant(mut self, _len: usize, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.read(|s| s.root_deserializer.deserialize_seq(visitor)) + } + + fn struct_variant(mut self, _fields: &'static [&'static str], visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + self.read(|s| s.root_deserializer.deserialize_map(visitor)) + } +} + +/// Deserializer used specifically for deserializing a document's cstring keys. +struct DocumentKeyDeserializer<'d, 'de> { + root_deserializer: &'d mut Deserializer<'de>, +} + +impl<'d, 'de> serde::de::Deserializer<'de> for DocumentKeyDeserializer<'d, 'de> { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + let s = self.root_deserializer.deserialize_document_key()?; + match s { + Cow::Borrowed(b) => visitor.visit_borrowed_str(b), + Cow::Owned(string) => visitor.visit_string(string), + } + } + + forward_to_deserialize_any! 
{ + bool char str bytes byte_buf option unit unit_struct string + identifier newtype_struct seq tuple tuple_struct struct map enum + ignored_any i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 + } +} + +/// Deserializer used to deserialize the given field name without any copies. +struct FieldDeserializer { + field_name: &'static str, +} + +impl<'de> serde::de::Deserializer<'de> for FieldDeserializer { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_borrowed_str(self.field_name) + } + + serde::forward_to_deserialize_any! { + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq + bytes byte_buf map struct option unit newtype_struct + ignored_any unit_struct tuple_struct tuple enum identifier + } +} + +struct ObjectIdAccess { + oid: ObjectId, + visited: bool, +} + +impl ObjectIdAccess { + fn new(oid: ObjectId) -> Self { + Self { + oid, + visited: false, + } + } +} + +impl<'de> serde::de::MapAccess<'de> for ObjectIdAccess { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: serde::de::DeserializeSeed<'de>, + { + if self.visited { + return Ok(None); + } + self.visited = true; + seed.deserialize(FieldDeserializer { field_name: "$oid" }) + .map(Some) + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: serde::de::DeserializeSeed<'de>, + { + seed.deserialize(ObjectIdDeserializer(self.oid)) + } +} + +struct ObjectIdDeserializer(ObjectId); + +impl<'de> serde::de::Deserializer<'de> for ObjectIdDeserializer { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + visitor.visit_string(self.0.to_hex()) + } + + serde::forward_to_deserialize_any! { + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq + bytes byte_buf map struct option unit newtype_struct + ignored_any unit_struct tuple_struct tuple enum identifier + } +} + +pub(crate) struct Decimal128Access { + decimal: Decimal128, + visited: bool, +} + +impl Decimal128Access { + pub(crate) fn new(decimal: Decimal128) -> Self { + Self { + decimal, + visited: false, + } + } +} + +impl<'de> serde::de::MapAccess<'de> for Decimal128Access { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: serde::de::DeserializeSeed<'de>, + { + if self.visited { + return Ok(None); + } + self.visited = true; + seed.deserialize(FieldDeserializer { + field_name: "$numberDecimalBytes", + }) + .map(Some) + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: serde::de::DeserializeSeed<'de>, + { + seed.deserialize(Decimal128Deserializer(self.decimal.clone())) + } +} + +struct Decimal128Deserializer(Decimal128); + +impl<'de> serde::de::Deserializer<'de> for Decimal128Deserializer { + type Error = Error; + + fn deserialize_any(self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + #[cfg(not(feature = "decimal128"))] + { + visitor.visit_bytes(&self.0.bytes) + } + + #[cfg(feature = "decimal128")] + { + visitor.visit_bytes(&self.0.to_raw_bytes_le()) + } + } + + serde::forward_to_deserialize_any! 
{ + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq + bytes byte_buf map struct option unit newtype_struct + ignored_any unit_struct tuple_struct tuple enum identifier + } +} + +enum TimestampDeserializationStage { + TopLevel, + Time, + Increment, + Done, +} + +struct TimestampAccess<'d> { + deserializer: &'d mut TimestampDeserializer, +} + +impl<'de, 'd> serde::de::MapAccess<'de> for TimestampAccess<'d> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: serde::de::DeserializeSeed<'de>, + { + match self.deserializer.stage { + TimestampDeserializationStage::TopLevel => seed + .deserialize(FieldDeserializer { + field_name: "$timestamp", + }) + .map(Some), + TimestampDeserializationStage::Time => seed + .deserialize(FieldDeserializer { field_name: "t" }) + .map(Some), + TimestampDeserializationStage::Increment => seed + .deserialize(FieldDeserializer { field_name: "i" }) + .map(Some), + TimestampDeserializationStage::Done => Ok(None), + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: serde::de::DeserializeSeed<'de>, + { + seed.deserialize(&mut *self.deserializer) + } +} + +struct TimestampDeserializer { + ts: Timestamp, + stage: TimestampDeserializationStage, +} + +impl TimestampDeserializer { + fn new(ts: Timestamp) -> Self { + Self { + ts, + stage: TimestampDeserializationStage::TopLevel, + } + } +} + +impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut TimestampDeserializer { + type Error = Error; + + fn deserialize_any(mut self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self.stage { + TimestampDeserializationStage::TopLevel => { + self.stage = TimestampDeserializationStage::Time; + visitor.visit_map(TimestampAccess { + deserializer: &mut self, + }) + } + TimestampDeserializationStage::Time => { + self.stage = TimestampDeserializationStage::Increment; + visitor.visit_u32(self.ts.time) + } + TimestampDeserializationStage::Increment => { + self.stage = TimestampDeserializationStage::Done; + visitor.visit_u32(self.ts.increment) + } + TimestampDeserializationStage::Done => { + Err(Error::custom("timestamp fully deserialized already")) + } + } + } + + serde::forward_to_deserialize_any! 
{ + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq + bytes byte_buf map struct option unit newtype_struct + ignored_any unit_struct tuple_struct tuple enum identifier + } +} + +enum DateTimeDeserializationStage { + TopLevel, + NumberLong, + Done, +} + +struct DateTimeAccess<'d> { + deserializer: &'d mut DateTimeDeserializer, +} + +impl<'de, 'd> serde::de::MapAccess<'de> for DateTimeAccess<'d> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: serde::de::DeserializeSeed<'de>, + { + match self.deserializer.stage { + DateTimeDeserializationStage::TopLevel => seed + .deserialize(FieldDeserializer { + field_name: "$date", + }) + .map(Some), + DateTimeDeserializationStage::NumberLong => seed + .deserialize(FieldDeserializer { + field_name: "$numberLong", + }) + .map(Some), + DateTimeDeserializationStage::Done => Ok(None), + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: serde::de::DeserializeSeed<'de>, + { + seed.deserialize(&mut *self.deserializer) + } +} + +struct DateTimeDeserializer { + dt: DateTime, + stage: DateTimeDeserializationStage, +} + +impl DateTimeDeserializer { + fn new(dt: DateTime) -> Self { + Self { + dt, + stage: DateTimeDeserializationStage::TopLevel, + } + } +} + +impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut DateTimeDeserializer { + type Error = Error; + + fn deserialize_any(mut self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self.stage { + DateTimeDeserializationStage::TopLevel => { + self.stage = DateTimeDeserializationStage::NumberLong; + visitor.visit_map(DateTimeAccess { + deserializer: &mut self, + }) + } + DateTimeDeserializationStage::NumberLong => { + self.stage = DateTimeDeserializationStage::Done; + visitor.visit_string(self.dt.timestamp_millis().to_string()) + } + DateTimeDeserializationStage::Done => { + Err(Error::custom("DateTime fully deserialized already")) + } + } + } + + serde::forward_to_deserialize_any! 
{ + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq + bytes byte_buf map struct option unit newtype_struct + ignored_any unit_struct tuple_struct tuple enum identifier + } +} + +struct BinaryAccess<'d> { + deserializer: &'d mut BinaryDeserializer, +} + +impl<'de, 'd> serde::de::MapAccess<'de> for BinaryAccess<'d> { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: serde::de::DeserializeSeed<'de>, + { + match self.deserializer.stage { + BinaryDeserializationStage::TopLevel => seed + .deserialize(FieldDeserializer { + field_name: "$binary", + }) + .map(Some), + BinaryDeserializationStage::Subtype => seed + .deserialize(FieldDeserializer { + field_name: "subType", + }) + .map(Some), + BinaryDeserializationStage::Bytes => seed + .deserialize(FieldDeserializer { + field_name: "base64", + }) + .map(Some), + BinaryDeserializationStage::Done => Ok(None), + } + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: serde::de::DeserializeSeed<'de>, + { + seed.deserialize(&mut *self.deserializer) + } +} + +struct BinaryDeserializer { + binary: Binary, + stage: BinaryDeserializationStage, +} + +impl BinaryDeserializer { + fn new(binary: Binary) -> Self { + Self { + binary, + stage: BinaryDeserializationStage::TopLevel, + } + } +} + +impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut BinaryDeserializer { + type Error = Error; + + fn deserialize_any(mut self, visitor: V) -> Result + where + V: serde::de::Visitor<'de>, + { + match self.stage { + BinaryDeserializationStage::TopLevel => { + self.stage = BinaryDeserializationStage::Subtype; + visitor.visit_map(BinaryAccess { + deserializer: &mut self, + }) + } + BinaryDeserializationStage::Subtype => { + self.stage = BinaryDeserializationStage::Bytes; + visitor.visit_string(hex::encode([u8::from(self.binary.subtype)])) + } + BinaryDeserializationStage::Bytes => { + self.stage = BinaryDeserializationStage::Done; + visitor.visit_string(base64::encode(self.binary.bytes.as_slice())) + } + BinaryDeserializationStage::Done => { + Err(Error::custom("Binary fully deserialized already")) + } + } + } + + serde::forward_to_deserialize_any! { + bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq + bytes byte_buf map struct option unit newtype_struct + ignored_any unit_struct tuple_struct tuple enum identifier + } +} + +enum BinaryDeserializationStage { + TopLevel, + Subtype, + Bytes, + Done, +} + +/// Struct wrapping a slice of BSON bytes. +struct BsonBuf<'a> { + bytes: &'a [u8], + index: usize, + + /// Whether or not to insert replacement characters in place of invalid UTF-8 sequences when + /// deserializing strings. + utf8_lossy: bool, +} + +impl<'a> Read for BsonBuf<'a> { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + self.index_check()?; + let bytes_read = self.bytes[self.index..].as_ref().read(buf)?; + self.index += bytes_read; + Ok(bytes_read) + } +} + +impl<'a> BsonBuf<'a> { + fn new(bytes: &'a [u8], utf8_lossy: bool) -> Self { + Self { + bytes, + index: 0, + utf8_lossy, + } + } + + fn bytes_read(&self) -> usize { + self.index + } + + /// Verify the index has not run out of bounds. + fn index_check(&self) -> std::io::Result<()> { + if self.index >= self.bytes.len() { + return Err(ErrorKind::UnexpectedEof.into()); + } + Ok(()) + } + + /// Get the starting at the provided index and ending at the buffer's current index. 
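// [Editor's note — illustrative sketch, not part of the diff.] The accessor/deserializer pairs
// above surface BSON-specific types to serde as small synthetic maps — Binary as
// "$binary"/"subType"/"base64", DateTime as "$date"/"$numberLong", Timestamp as
// "$timestamp"/"t"/"i" — which is what lets the crate's wrapper types deserialize themselves
// from raw bytes. A hypothetical usage sketch (field names and data are illustrative):
fn wrapper_types_sketch(raw: &[u8]) {
    use serde::Deserialize;

    #[derive(Deserialize)]
    struct Event {
        payload: bson::Binary,
        at: bson::DateTime,
    }

    // `raw` is assumed to be a well-formed BSON document containing a binary "payload" field
    // and a datetime "at" field, e.g. produced with `Document::to_writer` as in `all_types`.
    let event: Event = bson::from_slice(raw).expect("well-formed BSON");
    let _ = (event.payload.bytes, event.at.timestamp_millis());
}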
+    fn str(&mut self, start: usize) -> Result<Cow<'a, str>> {
+        let bytes = &self.bytes[start..self.index];
+        let s = if self.utf8_lossy {
+            String::from_utf8_lossy(bytes)
+        } else {
+            Cow::Borrowed(std::str::from_utf8(bytes).map_err(Error::custom)?)
+        };
+
+        // consume the null byte
+        if self.bytes[self.index] != 0 {
+            return Err(Error::custom("string was not null-terminated"));
+        }
+        self.index += 1;
+        self.index_check()?;
+
+        Ok(s)
+    }
+
+    /// Attempts to read a null-terminated UTF-8 cstring from the data.
+    ///
+    /// If utf8_lossy and invalid UTF-8 is encountered, the unicode replacement character will be
+    /// inserted in place of the offending data, resulting in an owned `String`. Otherwise, the
+    /// data will be borrowed as-is.
+    fn read_cstr(&mut self) -> Result<Cow<'a, str>> {
+        let start = self.index;
+        while self.index < self.bytes.len() && self.bytes[self.index] != 0 {
+            self.index += 1
+        }
+
+        self.index_check()?;
+
+        self.str(start)
+    }
+
+    /// Attempts to read a null-terminated UTF-8 string from the data.
+    ///
+    /// If utf8_lossy and invalid UTF-8 is encountered, the unicode replacement character will be
+    /// inserted in place of the offending data, resulting in an owned `String`. Otherwise, the
+    /// data will be borrowed as-is.
+    fn read_str(&mut self) -> Result<Cow<'a, str>> {
+        let len = read_i32(self)?;
+        let start = self.index;
+
+        // UTF-8 String must have at least 1 byte (the last 0x00).
+        if len < 1 {
+            return Err(Error::invalid_length(
+                len as usize,
+                &"UTF-8 string must have at least 1 byte",
+            ));
+        }
+
+        self.index += (len - 1) as usize;
+        self.index_check()?;
+
+        self.str(start)
+    }
+
+    fn read_slice(&mut self, length: usize) -> Result<&'a [u8]> {
+        let start = self.index;
+        self.index += length;
+        self.index_check()?;
+        Ok(&self.bytes[start..self.index])
+    }
+}
diff --git a/src/de/serde.rs b/src/de/serde.rs
index fbc6f53e8..7831a900c 100644
--- a/src/de/serde.rs
+++ b/src/de/serde.rs
@@ -1,4 +1,8 @@
-use std::{convert::TryFrom, fmt, vec};
+use std::{
+    convert::{TryFrom, TryInto},
+    fmt,
+    vec,
+};
 
 use serde::de::{
     self,
@@ -13,17 +17,19 @@ use serde::de::{
     VariantAccess,
     Visitor,
 };
+use serde_bytes::ByteBuf;
 
-#[cfg(feature = "decimal128")]
-use crate::decimal128::Decimal128;
 use crate::{
     bson::{Binary, Bson, DbPointer, JavaScriptCodeWithScope, Regex, Timestamp},
     datetime::DateTime,
-    document::{Document, DocumentVisitor, IntoIter},
+    document::{Document, IntoIter},
     oid::ObjectId,
     spec::BinarySubtype,
+    Decimal128,
 };
+
+use super::raw::Decimal128Access;
 
 pub(crate) struct BsonVisitor;
 
 impl<'de> Deserialize<'de> for ObjectId {
@@ -212,13 +218,212 @@ impl<'de> Visitor<'de> for BsonVisitor {
         Ok(Bson::Array(values))
     }
 
-    #[inline]
-    fn visit_map<V>(self, visitor: V) -> Result<Bson, V::Error>
+    fn visit_map<V>(self, mut visitor: V) -> Result<Bson, V::Error>
     where
         V: MapAccess<'de>,
     {
-        let values = DocumentVisitor::new().visit_map(visitor)?;
-        Ok(Bson::from_extended_document(values))
+        use crate::extjson;
+
+        let mut doc = Document::new();
+
+        while let Some(k) = visitor.next_key::<String>()? {
+            match k.as_str() {
+                "$oid" => {
+                    let hex: String = visitor.next_value()?;
+                    return Ok(Bson::ObjectId(ObjectId::parse_str(hex.as_str()).map_err(
+                        |_| {
+                            V::Error::invalid_value(
+                                Unexpected::Str(&hex),
+                                &"24-character, big-endian hex string",
+                            )
+                        },
+                    )?));
+                }
+                "$symbol" => {
+                    let string: String = visitor.next_value()?;
+                    return Ok(Bson::Symbol(string));
+                }
+
+                "$numberInt" => {
+                    let string: String = visitor.next_value()?;
+                    return Ok(Bson::Int32(string.parse().map_err(|_| {
+                        V::Error::invalid_value(
+                            Unexpected::Str(&string),
+                            &"32-bit signed integer as a string",
+                        )
+                    })?));
+                }
+
+                "$numberLong" => {
+                    let string: String = visitor.next_value()?;
+                    return Ok(Bson::Int64(string.parse().map_err(|_| {
+                        V::Error::invalid_value(
+                            Unexpected::Str(&string),
+                            &"64-bit signed integer as a string",
+                        )
+                    })?));
+                }
+
+                "$numberDouble" => {
+                    let string: String = visitor.next_value()?;
+                    let val = match string.as_str() {
+                        "Infinity" => Bson::Double(std::f64::INFINITY),
+                        "-Infinity" => Bson::Double(std::f64::NEG_INFINITY),
+                        "NaN" => Bson::Double(std::f64::NAN),
+                        _ => Bson::Double(string.parse().map_err(|_| {
+                            V::Error::invalid_value(
+                                Unexpected::Str(&string),
+                                &"64-bit double as a string",
+                            )
+                        })?),
+                    };
+                    return Ok(val);
+                }
+
+                "$binary" => {
+                    let v = visitor.next_value::<extjson::models::BinaryBody>()?;
+                    return Ok(Bson::Binary(
+                        extjson::models::Binary { body: v }
+                            .parse()
+                            .map_err(Error::custom)?,
+                    ));
+                }
+
+                "$code" => {
+                    let code = visitor.next_value::<String>()?;
+                    if let Some(key) = visitor.next_key::<String>()? {
+                        if key.as_str() == "$scope" {
+                            let scope = visitor.next_value::<Document>()?;
+                            return Ok(Bson::JavaScriptCodeWithScope(JavaScriptCodeWithScope {
+                                code,
+                                scope,
+                            }));
+                        } else {
+                            return Err(Error::unknown_field(key.as_str(), &["$scope"]));
+                        }
+                    } else {
+                        return Ok(Bson::JavaScriptCode(code));
+                    }
+                }
+
+                "$scope" => {
+                    let scope = visitor.next_value::<Document>()?;
+                    if let Some(key) = visitor.next_key::<String>()?
{ + if key.as_str() == "$code" { + let code = visitor.next_value::()?; + return Ok(Bson::JavaScriptCodeWithScope(JavaScriptCodeWithScope { + code, + scope, + })); + } else { + return Err(Error::unknown_field(key.as_str(), &["$code"])); + } + } else { + return Err(Error::missing_field("$code")); + } + } + + "$timestamp" => { + let ts = visitor.next_value::()?; + return Ok(Bson::Timestamp(Timestamp { + time: ts.t, + increment: ts.i, + })); + } + + "$regularExpression" => { + let re = visitor.next_value::()?; + return Ok(Bson::RegularExpression(Regex { + pattern: re.pattern, + options: re.options, + })); + } + + "$dbPointer" => { + let dbp = visitor.next_value::()?; + return Ok(Bson::DbPointer(DbPointer { + id: dbp.id.parse().map_err(Error::custom)?, + namespace: dbp.ref_ns, + })); + } + + "$date" => { + let dt = visitor.next_value::()?; + return Ok(Bson::DateTime( + extjson::models::DateTime { body: dt } + .parse() + .map_err(Error::custom)?, + )); + } + + "$maxKey" => { + let i = visitor.next_value::()?; + return extjson::models::MaxKey { value: i } + .parse() + .map_err(Error::custom); + } + + "$minKey" => { + let i = visitor.next_value::()?; + return extjson::models::MinKey { value: i } + .parse() + .map_err(Error::custom); + } + + "$undefined" => { + let b = visitor.next_value::()?; + return extjson::models::Undefined { value: b } + .parse() + .map_err(Error::custom); + } + + "$numberDecimal" => { + #[cfg(not(feature = "decimal128"))] + { + return Err(Error::custom(format!( + "enable the experimental decimal128 feature flag to deserialize \ + decimal128 from string" + ))); + } + + #[cfg(feature = "decimal128")] + { + let s = visitor.next_value::()?; + return Ok(Bson::Decimal128(s.parse().map_err(|_| { + Error::custom(format!("malformatted decimal128 string: {}", s)) + })?)); + } + } + + "$numberDecimalBytes" => { + let bytes = visitor.next_value::()?; + let arr = bytes.into_vec().try_into().map_err(|v: Vec| { + Error::custom(format!( + "expected decimal128 as byte buffer, instead got buffer of length {}", + v.len() + )) + })?; + #[cfg(not(feature = "decimal128"))] + { + return Ok(Bson::Decimal128(Decimal128 { bytes: arr })); + } + + #[cfg(feature = "decimal128")] + { + unsafe { + return Ok(Bson::Decimal128(Decimal128::from_raw_bytes_le(arr))); + } + } + } + + _ => { + let v = visitor.next_value::()?; + doc.insert(k, v); + } + } + } + + Ok(Bson::Document(doc)) } #[inline] @@ -231,6 +436,17 @@ impl<'de> Visitor<'de> for BsonVisitor { bytes: v.to_vec(), })) } + + #[inline] + fn visit_byte_buf(self, v: Vec) -> Result + where + E: Error, + { + Ok(Bson::Binary(Binary { + subtype: BinarySubtype::Generic, + bytes: v, + })) + } } fn convert_unsigned_to_signed(value: u64) -> Result @@ -334,15 +550,16 @@ impl<'de> de::Deserializer<'de> for Deserializer { Bson::Int64(v) => visitor.visit_i64(v), Bson::Binary(Binary { subtype: BinarySubtype::Generic, - ref bytes, - }) => visitor.visit_bytes(bytes), + bytes, + }) => visitor.visit_byte_buf(bytes), binary @ Bson::Binary(..) 
=> visitor.visit_map(MapDeserializer { - iter: binary.to_extended_document().into_iter(), + iter: binary.into_extended_document().into_iter(), value: None, len: 2, }), + Bson::Decimal128(d) => visitor.visit_map(Decimal128Access::new(d)), _ => { - let doc = value.to_extended_document(); + let doc = value.into_extended_document(); let len = doc.len(); visitor.visit_map(MapDeserializer { iter: doc.into_iter(), @@ -621,10 +838,21 @@ impl<'de> SeqAccess<'de> for SeqDeserializer { } } -struct MapDeserializer { - iter: IntoIter, - value: Option, - len: usize, +pub(crate) struct MapDeserializer { + pub(crate) iter: IntoIter, + pub(crate) value: Option, + pub(crate) len: usize, +} + +impl MapDeserializer { + pub(crate) fn new(doc: Document) -> Self { + let len = doc.len(); + MapDeserializer { + iter: doc.into_iter(), + len, + value: None, + } + } } impl<'de> MapAccess<'de> for MapDeserializer { diff --git a/src/document.rs b/src/document.rs index 2a4cba084..f9e1b6924 100644 --- a/src/document.rs +++ b/src/document.rs @@ -5,13 +5,12 @@ use std::{ fmt::{self, Debug, Display, Formatter}, io::{Read, Write}, iter::{Extend, FromIterator, IntoIterator}, - marker::PhantomData, mem, }; use ahash::RandomState; use indexmap::IndexMap; -use serde::de::{self, Error, MapAccess, Visitor}; +use serde::de::Error; #[cfg(feature = "decimal128")] use crate::decimal128::Decimal128; @@ -684,52 +683,6 @@ impl<'a> OccupiedEntry<'a> { } } -pub(crate) struct DocumentVisitor { - marker: PhantomData, -} - -impl DocumentVisitor { - #[allow(clippy::new_without_default)] - pub(crate) fn new() -> DocumentVisitor { - DocumentVisitor { - marker: PhantomData, - } - } -} - -impl<'de> Visitor<'de> for DocumentVisitor { - type Value = Document; - - fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "expecting ordered document") - } - - #[inline] - fn visit_unit(self) -> Result - where - E: de::Error, - { - Ok(Document::new()) - } - - #[inline] - fn visit_map(self, mut visitor: V) -> Result - where - V: MapAccess<'de>, - { - let mut inner = match visitor.size_hint() { - Some(size) => IndexMap::with_capacity_and_hasher(size, RandomState::default()), - None => IndexMap::default(), - }; - - while let Some((key, value)) = visitor.next_entry()? { - inner.insert(key, value); - } - - Ok(Document { inner }) - } -} - impl Extend<(String, Bson)> for Document { fn extend>(&mut self, iter: T) { for (k, v) in iter { diff --git a/src/extjson/mod.rs b/src/extjson/mod.rs index 18b40ecc6..1ad322cd7 100644 --- a/src/extjson/mod.rs +++ b/src/extjson/mod.rs @@ -90,4 +90,4 @@ //! 
``` pub mod de; -mod models; +pub(crate) mod models; diff --git a/src/extjson/models.rs b/src/extjson/models.rs index 4a9232fa8..551d58135 100644 --- a/src/extjson/models.rs +++ b/src/extjson/models.rs @@ -102,9 +102,9 @@ pub(crate) struct Regex { #[derive(Deserialize)] #[serde(deny_unknown_fields)] -struct RegexBody { - pattern: String, - options: String, +pub(crate) struct RegexBody { + pub(crate) pattern: String, + pub(crate) options: String, } impl Regex { @@ -124,15 +124,16 @@ impl Regex { #[serde(deny_unknown_fields)] pub(crate) struct Binary { #[serde(rename = "$binary")] - body: BinaryBody, + pub(crate) body: BinaryBody, } #[derive(Deserialize)] #[serde(deny_unknown_fields)] -struct BinaryBody { - base64: String, +pub(crate) struct BinaryBody { + pub(crate) base64: String, + #[serde(rename = "subType")] - subtype: String, + pub(crate) subtype: String, } impl Binary { @@ -209,8 +210,8 @@ pub(crate) struct Timestamp { #[derive(Deserialize)] #[serde(deny_unknown_fields)] pub(crate) struct TimestampBody { - t: u32, - i: u32, + pub(crate) t: u32, + pub(crate) i: u32, } impl Timestamp { @@ -226,12 +227,12 @@ impl Timestamp { #[serde(deny_unknown_fields)] pub(crate) struct DateTime { #[serde(rename = "$date")] - body: DateTimeBody, + pub(crate) body: DateTimeBody, } #[derive(Deserialize)] #[serde(untagged)] -enum DateTimeBody { +pub(crate) enum DateTimeBody { Canonical(Int64), Relaxed(String), } @@ -263,7 +264,7 @@ impl DateTime { #[serde(deny_unknown_fields)] pub(crate) struct MinKey { #[serde(rename = "$minKey")] - value: u8, + pub(crate) value: u8, } impl MinKey { @@ -283,7 +284,7 @@ impl MinKey { #[serde(deny_unknown_fields)] pub(crate) struct MaxKey { #[serde(rename = "$maxKey")] - value: u8, + pub(crate) value: u8, } impl MaxKey { @@ -308,12 +309,12 @@ pub(crate) struct DbPointer { #[derive(Deserialize)] #[serde(deny_unknown_fields)] -struct DbPointerBody { +pub(crate) struct DbPointerBody { #[serde(rename = "$ref")] - ref_ns: String, + pub(crate) ref_ns: String, #[serde(rename = "$id")] - id: ObjectId, + pub(crate) id: ObjectId, } impl DbPointer { @@ -350,7 +351,7 @@ impl Decimal128 { #[serde(deny_unknown_fields)] pub(crate) struct Undefined { #[serde(rename = "$undefined")] - value: bool, + pub(crate) value: bool, } impl Undefined { diff --git a/src/lib.rs b/src/lib.rs index 9f7f61f26..5b9f64eb7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -188,7 +188,15 @@ pub use self::{ bson::{Array, Binary, Bson, DbPointer, Document, JavaScriptCodeWithScope, Regex, Timestamp}, datetime::DateTime, - de::{from_bson, from_document, Deserializer}, + de::{ + from_bson, + from_document, + from_reader, + from_reader_utf8_lossy, + from_slice, + from_slice_utf8_lossy, + Deserializer, + }, decimal128::Decimal128, ser::{to_bson, to_document, Serializer}, }; diff --git a/src/ser/mod.rs b/src/ser/mod.rs index 556104fef..a26690072 100644 --- a/src/ser/mod.rs +++ b/src/ser/mod.rs @@ -29,7 +29,7 @@ pub use self::{ serde::Serializer, }; -use std::{io::Write, mem}; +use std::{io::Write, iter::FromIterator, mem}; #[cfg(feature = "decimal128")] use crate::decimal128::Decimal128; @@ -119,7 +119,11 @@ pub(crate) fn serialize_bson( ref options, }) => { write_cstring(writer, pattern)?; - write_cstring(writer, options) + + let mut chars: Vec = options.chars().collect(); + chars.sort_unstable(); + + write_cstring(writer, String::from_iter(chars).as_str()) } Bson::JavaScriptCode(ref code) => write_string(writer, code), Bson::ObjectId(ref id) => writer.write_all(&id.bytes()).map_err(From::from), diff --git 
a/src/ser/serde.rs b/src/ser/serde.rs index 0638974a8..e903b1e26 100644 --- a/src/ser/serde.rs +++ b/src/ser/serde.rs @@ -67,7 +67,7 @@ impl Serialize for Bson { ref bytes, }) => serializer.serialize_bytes(bytes), _ => { - let doc = self.to_extended_document(); + let doc = self.clone().into_extended_document(); doc.serialize(serializer) } } diff --git a/src/tests/spec/corpus.rs b/src/tests/spec/corpus.rs index f1e7ee5f1..a1a574b85 100644 --- a/src/tests/spec/corpus.rs +++ b/src/tests/spec/corpus.rs @@ -59,36 +59,119 @@ fn run_test(test: TestFile) { for valid in test.valid { let description = format!("{}: {}", test.description, valid.description); - let bson_to_native_cb = Document::from_reader( - &mut hex::decode(&valid.canonical_bson) - .expect(&description) - .as_slice(), - ) - .expect(&description); - - let mut native_to_bson_bson_to_native_cv = Vec::new(); + let canonical_bson = hex::decode(&valid.canonical_bson).expect(&description); + + let bson_to_native_cb = + Document::from_reader(canonical_bson.as_slice()).expect(&description); + + let bson_to_native_cb_serde: Document = + crate::from_reader(canonical_bson.as_slice()).expect(&description); + + let native_to_native_cb_serde: Document = + crate::from_document(bson_to_native_cb.clone()).expect(&description); + + let mut native_to_bson_bson_to_native_cb = Vec::new(); bson_to_native_cb - .to_writer(&mut native_to_bson_bson_to_native_cv) + .to_writer(&mut native_to_bson_bson_to_native_cb) .expect(&description); - // TODO RUST-36: Enable decimal128 tests. - // extJSON not implemented for decimal128 without the feature flag, so we must stop here. - if test.bson_type == "0x13" && !cfg!(feature = "decimal128") { - continue; - } + let mut native_to_bson_bson_to_native_cb_serde = Vec::new(); + bson_to_native_cb_serde + .to_writer(&mut native_to_bson_bson_to_native_cb_serde) + .expect(&description); - let cej: serde_json::Value = - serde_json::from_str(&valid.canonical_extjson).expect(&description); + let mut native_to_bson_native_to_native_cb_serde = Vec::new(); + native_to_native_cb_serde + .to_writer(&mut native_to_bson_native_to_native_cb_serde) + .expect(&description); // native_to_bson( bson_to_native(cB) ) = cB assert_eq!( - hex::encode(native_to_bson_bson_to_native_cv).to_lowercase(), + hex::encode(native_to_bson_bson_to_native_cb).to_lowercase(), + valid.canonical_bson.to_lowercase(), + "{}", + description, + ); + + assert_eq!( + hex::encode(native_to_bson_bson_to_native_cb_serde).to_lowercase(), valid.canonical_bson.to_lowercase(), "{}", description, ); + assert_eq!( + hex::encode(native_to_bson_native_to_native_cb_serde).to_lowercase(), + valid.canonical_bson.to_lowercase(), + "{}", + description, + ); + + // NaN == NaN is false, so we skip document comparisons that contain NaN + if !description.to_ascii_lowercase().contains("nan") && !description.contains("decq541") { + assert_eq!( + bson_to_native_cb, bson_to_native_cb_serde, + "{}", + description + ); + + assert_eq!( + bson_to_native_cb, native_to_native_cb_serde, + "{}", + description + ); + } + + // native_to_bson( bson_to_native(dB) ) = cB + + if let Some(db) = valid.degenerate_bson { + let db = hex::decode(&db).expect(&description); + + let bson_to_native_db = Document::from_reader(db.as_slice()).expect(&description); + let mut native_to_bson_bson_to_native_db = Vec::new(); + bson_to_native_db + .to_writer(&mut native_to_bson_bson_to_native_db) + .unwrap(); + assert_eq!( + hex::encode(native_to_bson_bson_to_native_db).to_lowercase(), + valid.canonical_bson.to_lowercase(), + 
"{}", + description, + ); + + let bson_to_native_db_serde: Document = + crate::from_reader(db.as_slice()).expect(&description); + let mut native_to_bson_bson_to_native_db_serde = Vec::new(); + bson_to_native_db_serde + .to_writer(&mut native_to_bson_bson_to_native_db_serde) + .unwrap(); + assert_eq!( + hex::encode(native_to_bson_bson_to_native_db_serde).to_lowercase(), + valid.canonical_bson.to_lowercase(), + "{}", + description, + ); + + // NaN == NaN is false, so we skip document comparisons that contain NaN + if !description.contains("NaN") { + assert_eq!( + bson_to_native_db_serde, bson_to_native_cb, + "{}", + description + ); + } + } + + // TODO RUST-36: Enable decimal128 tests. + // extJSON not implemented for decimal128 without the feature flag, so we must stop here. + if test.bson_type == "0x13" && !cfg!(feature = "decimal128") { + continue; + } + + let cej: serde_json::Value = + serde_json::from_str(&valid.canonical_extjson).expect(&description); + // native_to_canonical_extended_json( bson_to_native(cB) ) = cEJ let mut cej_updated_float = cej.clone(); @@ -172,26 +255,6 @@ fn run_test(test: TestFile) { } } - // native_to_bson( bson_to_native(dB) ) = cB - - if let Some(db) = valid.degenerate_bson { - let bson_to_native_db = - Document::from_reader(&mut hex::decode(&db).expect(&description).as_slice()) - .expect(&description); - - let mut native_to_bson_bson_to_native_db = Vec::new(); - bson_to_native_db - .to_writer(&mut native_to_bson_bson_to_native_db) - .unwrap(); - - assert_eq!( - hex::encode(native_to_bson_bson_to_native_db).to_lowercase(), - valid.canonical_bson.to_lowercase(), - "{}", - description, - ); - } - if let Some(ref degenerate_extjson) = valid.degenerate_extjson { let dej: serde_json::Value = serde_json::from_str(degenerate_extjson).expect(&description); @@ -252,7 +315,7 @@ fn run_test(test: TestFile) { } } - for decode_error in test.decode_errors { + for decode_error in test.decode_errors.iter() { // No meaningful definition of "byte count" for an arbitrary reader. if decode_error.description == "Stated length less than byte count, with garbage after envelope" @@ -260,8 +323,22 @@ fn run_test(test: TestFile) { continue; } - let bson = hex::decode(decode_error.bson).expect("should decode from hex"); - Document::from_reader(&mut bson.as_slice()).expect_err(decode_error.description.as_str()); + let description = format!( + "{} decode error: {}", + test.bson_type, decode_error.description + ); + let bson = hex::decode(&decode_error.bson).expect("should decode from hex"); + Document::from_reader(bson.as_slice()).expect_err(&description); + crate::from_reader::<_, Document>(bson.as_slice()).expect_err(description.as_str()); + + if decode_error.description.contains("invalid UTF-8") { + let d = crate::from_reader_utf8_lossy::<_, Document>(bson.as_slice()) + .unwrap_or_else(|_| panic!("{}: utf8_lossy should not fail", description)); + if let Some(ref key) = test.test_key { + d.get_str(key) + .unwrap_or_else(|_| panic!("{}: value should be a string", description)); + } + } } for parse_error in test.parse_errors {