diff --git a/pw_protobuf/pw_protobuf_test_protos/full_test.proto b/pw_protobuf/pw_protobuf_test_protos/full_test.proto index ccb0d1e234..eca7ac1886 100644 --- a/pw_protobuf/pw_protobuf_test_protos/full_test.proto +++ b/pw_protobuf/pw_protobuf_test_protos/full_test.proto @@ -142,3 +142,9 @@ message KeyValuePair { string key = 1; string value = 2; } + +// Corner cases of code generation. +message CornerCases { + // Generates ReadUint32() and WriteUint32() that call the parent definition. + uint32 _uint32 = 1; +} diff --git a/pw_protobuf/py/pw_protobuf/codegen_pwpb.py b/pw_protobuf/py/pw_protobuf/codegen_pwpb.py index f23b2c2fe2..5c4757b789 100644 --- a/pw_protobuf/py/pw_protobuf/codegen_pwpb.py +++ b/pw_protobuf/py/pw_protobuf/codegen_pwpb.py @@ -88,12 +88,8 @@ def debug_print(*args, **kwargs): class ProtoMember(abc.ABC): """Base class for a C++ class member for a field in a protobuf message.""" - def __init__( - self, - field: ProtoMessageField, - scope: ProtoNode, - root: ProtoNode, - ): + def __init__(self, field: ProtoMessageField, scope: ProtoNode, + root: ProtoNode): """Creates an instance of a class member. Args: @@ -135,6 +131,11 @@ def _relative_type_namespace(self, from_root: bool = False) -> str: class ProtoMethod(ProtoMember): """Base class for a C++ method for a field in a protobuf message.""" + def __init__(self, field: ProtoMessageField, scope: ProtoNode, + root: ProtoNode, base_class: str): + super().__init__(field, scope, root) + self._base_class: str = base_class + @abc.abstractmethod def params(self) -> List[Tuple[str, str]]: """Returns the parameters of the method as a list of (type, name) pairs. @@ -198,8 +199,9 @@ def return_type(self, from_root: bool = False) -> str: def body(self) -> List[str]: params = ', '.join([pair[1] for pair in self.params()]) - line = 'return {}({}, {});'.format(self._encoder_fn(), - self.field_cast(), params) + line = 'return {}::{}({}, {});'.format(self._base_class, + self._encoder_fn(), + self.field_cast(), params) return [line] def params(self) -> List[Tuple[str, str]]: @@ -272,7 +274,8 @@ def body(self) -> List[str]: def _decoder_body(self) -> List[str]: """Returns the decoder body part as a list of source code lines.""" params = ', '.join([pair[1] for pair in self.params()]) - line = 'return {}({});'.format(self._decoder_fn(), params) + line = 'return {}::{}({});'.format(self._base_class, + self._decoder_fn(), params) return [line] def _decoder_fn(self) -> str: @@ -442,8 +445,9 @@ def params(self) -> List[Tuple[str, str]]: return [] def body(self) -> List[str]: - line = 'return {}::StreamEncoder(GetNestedEncoder({}));'.format( - self._relative_type_namespace(), self.field_cast()) + line = 'return {}::StreamEncoder({}::GetNestedEncoder({}));'.format( + self._relative_type_namespace(), self._base_class, + self.field_cast()) return [line] # Submessage methods are not defined within the class itself because the @@ -1470,8 +1474,9 @@ def params(self) -> List[Tuple[str, str]]: return [(self._relative_type_namespace(), 'value')] def body(self) -> List[str]: - line = 'return WriteUint32(' \ - '{}, static_cast(value));'.format(self.field_cast()) + line = ('return {}::WriteUint32({}, ' + 'static_cast(value));'.format( + self._base_class, self.field_cast())) return [line] def in_class_definition(self) -> bool: @@ -1489,9 +1494,10 @@ def params(self) -> List[Tuple[str, str]]: def body(self) -> List[str]: value_param = self.params()[0][1] - line = (f'return WritePackedUint32({self.field_cast()}, std::span(' - f'reinterpret_cast({value_param}.data()), ' - f'{value_param}.size()));') + line = ( + f'return {self._base_class}::WritePackedUint32(' + f'{self.field_cast()}, std::span(reinterpret_cast(' + f'{value_param}.data()), {value_param}.size()));') return [line] def in_class_definition(self) -> bool: @@ -1767,7 +1773,7 @@ def generate_class_for_message(message: ProtoMessage, root: ProtoNode, # Generate methods for each of the message's fields. for field in message.fields(): for method_class in proto_field_methods(class_type, field.type()): - method = method_class(field, message, root) + method = method_class(field, message, root, base_class) if not method.should_appear(): continue @@ -1797,9 +1803,12 @@ def define_not_in_class_methods(message: ProtoMessage, root: ProtoNode, """Defines methods for a message class that were previously declared.""" assert message.type() == ProtoNode.Type.MESSAGE + base_class_name = class_type.base_class_name() + base_class = f'{PROTOBUF_NAMESPACE}::{base_class_name}' + for field in message.fields(): for method_class in proto_field_methods(class_type, field.type()): - method = method_class(field, message, root) + method = method_class(field, message, root, base_class) if not method.should_appear() or method.in_class_definition(): continue