From 85c698a288bd2ae94cde44109f853214fb00915f Mon Sep 17 00:00:00 2001 From: Casper Meijn Date: Mon, 20 May 2024 18:28:46 +0200 Subject: [PATCH] feat: derive Copy trait for messages where possible (#950) * feat: derive Copy trait for messages where possible Rust primitive types can be copied by simply copying the bits. Rust structs can also have this property by deriving the Copy trait. Automatically derive Copy for: - messages that only have fields with primitive types - the Rust enum for one-of fields - messages whose field type are messages that also implement Copy Generated code for Protobuf enums already derives Copy. * fix: Remove clone call when copy is implemented Clippy reports: warning: using `clone` on type `Timestamp` which implements the `Copy` trait --- prost-build/src/code_generator.rs | 15 ++++- .../_expected_field_attributes.rs | 4 +- .../_expected_field_attributes_formatted.rs | 4 +- prost-build/src/message_graph.rs | 59 +++++++++++++++++-- prost-types/src/datetime.rs | 2 +- prost-types/src/duration.rs | 6 +- prost-types/src/protobuf.rs | 8 +-- prost-types/src/timestamp.rs | 11 ++-- tests/src/build.rs | 4 ++ tests/src/derive_copy.proto | 51 ++++++++++++++++ tests/src/derive_copy.rs | 21 +++++++ tests/src/lib.rs | 2 + 12 files changed, 163 insertions(+), 24 deletions(-) create mode 100644 tests/src/derive_copy.proto create mode 100644 tests/src/derive_copy.rs diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 6ca8581ab..65ee53d71 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -231,7 +231,12 @@ impl<'a> CodeGenerator<'a> { self.buf .push_str("#[allow(clippy::derive_partial_eq_without_eq)]\n"); self.buf.push_str(&format!( - "#[derive(Clone, PartialEq, {}::Message)]\n", + "#[derive(Clone, {}PartialEq, {}::Message)]\n", + if self.message_graph.can_message_derive_copy(&fq_message_name) { + "Copy, " + } else { + "" + }, prost_path(self.config) )); self.append_skip_debug(&fq_message_name); @@ -613,8 +618,14 @@ impl<'a> CodeGenerator<'a> { self.push_indent(); self.buf .push_str("#[allow(clippy::derive_partial_eq_without_eq)]\n"); + + let can_oneof_derive_copy = fields.iter().map(|(field, _idx)| field).all(|field| { + self.message_graph + .can_field_derive_copy(fq_message_name, field) + }); self.buf.push_str(&format!( - "#[derive(Clone, PartialEq, {}::Oneof)]\n", + "#[derive(Clone, {}PartialEq, {}::Oneof)]\n", + if can_oneof_derive_copy { "Copy, " } else { "" }, prost_path(self.config) )); self.append_skip_debug(fq_message_name); diff --git a/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs b/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs index 95fb05d86..f58bbb0ba 100644 --- a/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs +++ b/prost-build/src/fixtures/field_attributes/_expected_field_attributes.rs @@ -23,12 +23,12 @@ pub struct Foo { pub foo: ::prost::alloc::string::String, } #[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Bar { #[prost(message, optional, boxed, tag="1")] pub qux: ::core::option::Option<::prost::alloc::boxed::Box>, } #[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Qux { } diff --git a/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs b/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs index f1eaee751..0aabf753f 100644 --- a/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs +++ b/prost-build/src/fixtures/field_attributes/_expected_field_attributes_formatted.rs @@ -23,11 +23,11 @@ pub struct Foo { pub foo: ::prost::alloc::string::String, } #[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Bar { #[prost(message, optional, boxed, tag = "1")] pub qux: ::core::option::Option<::prost::alloc::boxed::Box>, } #[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Qux {} diff --git a/prost-build/src/message_graph.rs b/prost-build/src/message_graph.rs index ac0ad1523..9cc40f975 100644 --- a/prost-build/src/message_graph.rs +++ b/prost-build/src/message_graph.rs @@ -4,7 +4,10 @@ use petgraph::algo::has_path_connecting; use petgraph::graph::NodeIndex; use petgraph::Graph; -use prost_types::{field_descriptor_proto, DescriptorProto, FileDescriptorProto}; +use prost_types::{ + field_descriptor_proto::{Label, Type}, + DescriptorProto, FieldDescriptorProto, FileDescriptorProto, +}; /// `MessageGraph` builds a graph of messages whose edges correspond to nesting. /// The goal is to recognize when message types are recursively nested, so @@ -12,6 +15,7 @@ use prost_types::{field_descriptor_proto, DescriptorProto, FileDescriptorProto}; pub struct MessageGraph { index: HashMap, graph: Graph, + messages: HashMap, } impl MessageGraph { @@ -21,6 +25,7 @@ impl MessageGraph { let mut msg_graph = MessageGraph { index: HashMap::new(), graph: Graph::new(), + messages: HashMap::new(), }; for file in files { @@ -41,6 +46,7 @@ impl MessageGraph { let MessageGraph { ref mut index, ref mut graph, + .. } = *self; assert_eq!(b'.', msg_name.as_bytes()[0]); *index @@ -58,13 +64,12 @@ impl MessageGraph { let msg_index = self.get_or_insert_index(msg_name.clone()); for field in &msg.field { - if field.r#type() == field_descriptor_proto::Type::Message - && field.label() != field_descriptor_proto::Label::Repeated - { + if field.r#type() == Type::Message && field.label() != Label::Repeated { let field_index = self.get_or_insert_index(field.type_name.clone().unwrap()); self.graph.add_edge(msg_index, field_index, ()); } } + self.messages.insert(msg_name.clone(), msg.clone()); for msg in &msg.nested_type { self.add_message(&msg_name, msg); @@ -84,4 +89,50 @@ impl MessageGraph { has_path_connecting(&self.graph, outer, inner, None) } + + /// Returns `true` if this message can automatically derive Copy trait. + pub fn can_message_derive_copy(&self, fq_message_name: &str) -> bool { + assert_eq!(".", &fq_message_name[..1]); + let msg = self.messages.get(fq_message_name).unwrap(); + msg.field + .iter() + .all(|field| self.can_field_derive_copy(fq_message_name, field)) + } + + /// Returns `true` if the type of this field allows deriving the Copy trait. + pub fn can_field_derive_copy( + &self, + fq_message_name: &str, + field: &FieldDescriptorProto, + ) -> bool { + assert_eq!(".", &fq_message_name[..1]); + + if field.label() == Label::Repeated { + false + } else if field.r#type() == Type::Message { + if self.is_nested(field.type_name(), fq_message_name) { + false + } else { + self.can_message_derive_copy(field.type_name()) + } + } else { + matches!( + field.r#type(), + Type::Float + | Type::Double + | Type::Int32 + | Type::Int64 + | Type::Uint32 + | Type::Uint64 + | Type::Sint32 + | Type::Sint64 + | Type::Fixed32 + | Type::Fixed64 + | Type::Sfixed32 + | Type::Sfixed64 + | Type::Bool + | Type::Enum + ) + } + } } diff --git a/prost-types/src/datetime.rs b/prost-types/src/datetime.rs index 2435ffe73..9c3b3f37b 100644 --- a/prost-types/src/datetime.rs +++ b/prost-types/src/datetime.rs @@ -614,7 +614,7 @@ mod tests { }; assert_eq!( expected, - format!("{}", DateTime::from(timestamp.clone())), + format!("{}", DateTime::from(timestamp)), "timestamp: {:?}", timestamp ); diff --git a/prost-types/src/duration.rs b/prost-types/src/duration.rs index 600716933..3d01f1df4 100644 --- a/prost-types/src/duration.rs +++ b/prost-types/src/duration.rs @@ -105,7 +105,7 @@ impl TryFrom for time::Duration { impl fmt::Display for Duration { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut d = self.clone(); + let mut d = *self; d.normalize(); if self.seconds < 0 && self.nanos < 0 { write!(f, "-")?; @@ -193,7 +193,7 @@ mod tests { Ok(duration) => duration, Err(_) => return Err(TestCaseError::reject("duration out of range")), }; - prop_assert_eq!(time::Duration::try_from(prost_duration.clone()).unwrap(), std_duration); + prop_assert_eq!(time::Duration::try_from(prost_duration).unwrap(), std_duration); if std_duration != time::Duration::default() { let neg_prost_duration = Duration { @@ -220,7 +220,7 @@ mod tests { Ok(duration) => duration, Err(_) => return Err(TestCaseError::reject("duration out of range")), }; - prop_assert_eq!(time::Duration::try_from(prost_duration.clone()).unwrap(), std_duration); + prop_assert_eq!(time::Duration::try_from(prost_duration).unwrap(), std_duration); if std_duration != time::Duration::default() { let neg_prost_duration = Duration { diff --git a/prost-types/src/protobuf.rs b/prost-types/src/protobuf.rs index edc1361be..34de0ec88 100644 --- a/prost-types/src/protobuf.rs +++ b/prost-types/src/protobuf.rs @@ -94,7 +94,7 @@ pub mod descriptor_proto { /// fields or extension ranges in the same message. Reserved ranges may /// not overlap. #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Message)] + #[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct ReservedRange { /// Inclusive. #[prost(int32, optional, tag = "1")] @@ -360,7 +360,7 @@ pub mod enum_descriptor_proto { /// is inclusive such that it can appropriately represent the entire int32 /// domain. #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Message)] + #[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct EnumReservedRange { /// Inclusive. #[prost(int32, optional, tag = "1")] @@ -1853,7 +1853,7 @@ pub struct Mixin { /// be expressed in JSON format as "3.000000001s", and 3 seconds and 1 /// microsecond should be expressed in JSON format as "3.000001s". #[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Duration { /// Signed seconds of the span of time. Must be from -315,576,000,000 /// to +315,576,000,000 inclusive. Note: these bounds are computed from: @@ -2293,7 +2293,7 @@ impl NullValue { /// the time format spec '%Y-%m-%dT%H:%M:%S.%fZ'. Likewise, in Java, one can use /// the Joda Time's [`ISODateTimeFormat.dateTime()`]() to obtain a formatter capable of generating timestamps in this format. #[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, ::prost::Message)] pub struct Timestamp { /// Represents seconds of UTC time since Unix epoch /// 1970-01-01T00:00:00Z. Must be from 0001-01-01T00:00:00Z to diff --git a/prost-types/src/timestamp.rs b/prost-types/src/timestamp.rs index 216712a4b..46cf4044d 100644 --- a/prost-types/src/timestamp.rs +++ b/prost-types/src/timestamp.rs @@ -50,7 +50,7 @@ impl Timestamp { /// /// [1]: https://github.com/google/protobuf/blob/v3.3.2/src/google/protobuf/util/time_util.cc#L59-L77 pub fn try_normalize(mut self) -> Result { - let before = self.clone(); + let before = self; self.normalize(); // If the seconds value has changed, and is either i64::MIN or i64::MAX, then the timestamp // normalization overflowed. @@ -201,7 +201,7 @@ impl TryFrom for std::time::SystemTime { type Error = TimestampError; fn try_from(mut timestamp: Timestamp) -> Result { - let orig_timestamp = timestamp.clone(); + let orig_timestamp = timestamp; timestamp.normalize(); let system_time = if timestamp.seconds >= 0 { @@ -211,8 +211,7 @@ impl TryFrom for std::time::SystemTime { timestamp .seconds .checked_neg() - .ok_or_else(|| TimestampError::OutOfSystemRange(timestamp.clone()))? - as u64, + .ok_or(TimestampError::OutOfSystemRange(timestamp))? as u64, )) }; @@ -234,7 +233,7 @@ impl FromStr for Timestamp { impl fmt::Display for Timestamp { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - datetime::DateTime::from(self.clone()).fmt(f) + datetime::DateTime::from(*self).fmt(f) } } #[cfg(test)] @@ -262,7 +261,7 @@ mod tests { ) { let mut timestamp = Timestamp { seconds, nanos }; timestamp.normalize(); - if let Ok(system_time) = SystemTime::try_from(timestamp.clone()) { + if let Ok(system_time) = SystemTime::try_from(timestamp) { prop_assert_eq!(Timestamp::from(system_time), timestamp); } } diff --git a/tests/src/build.rs b/tests/src/build.rs index 0403d2603..68f2a7aeb 100644 --- a/tests/src/build.rs +++ b/tests/src/build.rs @@ -91,6 +91,10 @@ fn main() { .compile_protos(&[src.join("deprecated_field.proto")], includes) .unwrap(); + config + .compile_protos(&[src.join("derive_copy.proto")], includes) + .unwrap(); + config .compile_protos(&[src.join("default_string_escape.proto")], includes) .unwrap(); diff --git a/tests/src/derive_copy.proto b/tests/src/derive_copy.proto new file mode 100644 index 000000000..d2a472bf8 --- /dev/null +++ b/tests/src/derive_copy.proto @@ -0,0 +1,51 @@ +syntax = "proto3"; + +import "google/protobuf/timestamp.proto"; + +package derive_copy; + +message EmptyMsg {} + +message IntegerMsg { + int32 field1 = 1; + int64 field2 = 2; + uint32 field3 = 3; + uint64 field4 = 4; + sint32 field5 = 5; + sint64 field6 = 6; + fixed32 field7 = 7; + fixed64 field8 = 8; + sfixed32 field9 = 9; + sfixed64 field10 = 10; +} + +message FloatMsg { + double field1 = 1; + float field2 = 2; +} + +message BoolMsg { bool field1 = 1; } + +enum AnEnum { + A = 0; + B = 1; +}; + +message EnumMsg { AnEnum field1 = 1; } + +message OneOfMsg { + oneof data { + int32 field1 = 1; + int64 field2 = 2; + } +} + +message ComposedMsg { + IntegerMsg field1 = 1; + EnumMsg field2 = 2; + OneOfMsg field3 = 3; +} + +message WellKnownMsg { + google.protobuf.Timestamp timestamp = 1; +} diff --git a/tests/src/derive_copy.rs b/tests/src/derive_copy.rs new file mode 100644 index 000000000..33b4fc84f --- /dev/null +++ b/tests/src/derive_copy.rs @@ -0,0 +1,21 @@ +include!(concat!(env!("OUT_DIR"), "/derive_copy.rs")); + +trait TestCopyIsImplemented: Copy {} + +impl TestCopyIsImplemented for EmptyMsg {} + +impl TestCopyIsImplemented for IntegerMsg {} + +impl TestCopyIsImplemented for FloatMsg {} + +impl TestCopyIsImplemented for BoolMsg {} + +impl TestCopyIsImplemented for AnEnum {} + +impl TestCopyIsImplemented for EnumMsg {} + +impl TestCopyIsImplemented for OneOfMsg {} + +impl TestCopyIsImplemented for ComposedMsg {} + +impl TestCopyIsImplemented for WellKnownMsg {} diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 6674ddddd..5022fbae0 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -36,6 +36,8 @@ mod debug; #[cfg(test)] mod deprecated_field; #[cfg(test)] +mod derive_copy; +#[cfg(test)] mod enum_keyword_variant; #[cfg(test)] mod generic_derive;