diff --git a/pw_protobuf/public/pw_protobuf/internal/codegen.h b/pw_protobuf/public/pw_protobuf/internal/codegen.h index dfc919b858..5d4287c7b8 100644 --- a/pw_protobuf/public/pw_protobuf/internal/codegen.h +++ b/pw_protobuf/public/pw_protobuf/internal/codegen.h @@ -149,7 +149,7 @@ static_assert(sizeof(MessageField) <= sizeof(size_t) * 4, // a field. template union Callback { - Callback() : encode_() {} + constexpr Callback() : encode_() {} ~Callback() { encode_ = nullptr; } // Set the encoder callback. diff --git a/pw_protobuf/py/pw_protobuf/proto_tree.py b/pw_protobuf/py/pw_protobuf/proto_tree.py index a931cad31d..f0b5b3d53a 100644 --- a/pw_protobuf/py/pw_protobuf/proto_tree.py +++ b/pw_protobuf/py/pw_protobuf/proto_tree.py @@ -78,6 +78,14 @@ def proto_path(self) -> str: path = '.'.join(self._attr_hierarchy(lambda node: node.name(), None)) return path.lstrip('.') + def pwpb_struct(self) -> str: + """Name of the pw_protobuf struct for this proto.""" + return '::' + self.cpp_namespace() + '::Message' + + def pwpb_table(self) -> str: + """Name of the pw_protobuf table constant for this proto.""" + return '::' + self.cpp_namespace() + '::kMessageFields' + def nanopb_fields(self) -> str: """Name of the Nanopb variable that represents the proto fields.""" return self._nanopb_name() + '_fields' diff --git a/pw_protobuf_compiler/BUILD.bazel b/pw_protobuf_compiler/BUILD.bazel index 045285e843..2cf855d801 100644 --- a/pw_protobuf_compiler/BUILD.bazel +++ b/pw_protobuf_compiler/BUILD.bazel @@ -18,12 +18,17 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) -# TODO(frolv): Figure out how to support nanopb codegen in Bazel. +# TODO(frolv): Figure out how to support nanopb and pwpb codegen in Bazel. filegroup( name = "nanopb_test", srcs = ["nanopb_test.cc"], ) +filegroup( + name = "pwpb_test", + srcs = ["pwpb_test.cc"], +) + py_proto_library( name = "pw_protobuf_compiler_protos", srcs = [ diff --git a/pw_protobuf_compiler/BUILD.gn b/pw_protobuf_compiler/BUILD.gn index 62f2819cc3..13a8f2b41f 100644 --- a/pw_protobuf_compiler/BUILD.gn +++ b/pw_protobuf_compiler/BUILD.gn @@ -25,7 +25,10 @@ pw_doc_group("docs") { } pw_test_group("tests") { - tests = [ ":nanopb_test" ] + tests = [ + ":nanopb_test", + ":pwpb_test", + ] } pw_test("nanopb_test") { @@ -42,6 +45,16 @@ pw_proto_library("nanopb_test_protos") { } } +pw_test("pwpb_test") { + deps = [ ":pwpb_test_protos.pwpb" ] + sources = [ "pwpb_test.cc" ] +} + +pw_proto_library("pwpb_test_protos") { + sources = [ "pw_protobuf_compiler_pwpb_protos/pwpb_test.proto" ] + inputs = [ "pw_protobuf_compiler_pwpb_protos/pwpb_test.options" ] +} + pw_proto_library("test_protos") { sources = [ "pw_protobuf_compiler_protos/nested/more_nesting/test.proto", diff --git a/pw_protobuf_compiler/CMakeLists.txt b/pw_protobuf_compiler/CMakeLists.txt index 4b5a3b62ac..0a1ddfdfa7 100644 --- a/pw_protobuf_compiler/CMakeLists.txt +++ b/pw_protobuf_compiler/CMakeLists.txt @@ -15,6 +15,22 @@ include($ENV{PW_ROOT}/pw_build/pigweed.cmake) include($ENV{PW_ROOT}/pw_protobuf_compiler/proto.cmake) +pw_proto_library(pw_protobuf_compiler.pwpb_test_protos + SOURCES + pw_protobuf_compiler_pwpb_protos/pwpb_test.proto + INPUTS + pw_protobuf_compiler_pwpb_protos/pwpb_test.options +) + +pw_add_test(pw_protobuf_compiler.pwpb_test + SOURCES + pwpb_test.cc + DEPS + pw_protobuf_compiler.pwpb_test_protos.pwpb + GROUPS + pw_protobuf_compiler +) + if(NOT "${dir_pw_third_party_nanopb}" STREQUAL "") pw_proto_library(pw_protobuf_compiler.nanopb_test_protos SOURCES diff --git a/pw_protobuf_compiler/docs.rst b/pw_protobuf_compiler/docs.rst index 31160e865f..48053d622d 100644 --- a/pw_protobuf_compiler/docs.rst +++ b/pw_protobuf_compiler/docs.rst @@ -15,6 +15,9 @@ Protobuf code generation is currently supported for the following generators: +-------------+----------------+-----------------------------------------------+ | pw_protobuf | ``pwpb`` | Compiles using ``pw_protobuf``. | +-------------+----------------+-----------------------------------------------+ +| pw_protobuf | ``pwpb_rpc`` | Compiles pw_rpc service and client code for | +| RPC | | ``pw_protobuf``. | ++-------------+----------------+-----------------------------------------------+ | Nanopb | ``nanopb`` | Compiles using Nanopb. The build argument | | | | ``dir_pw_third_party_nanopb`` must be set to | | | | point to a local nanopb installation. | @@ -81,6 +84,7 @@ GN supports the following compiled proto libraries via the specified sub-targets generated by a ``pw_proto_library``. * ``${target_name}.pwpb`` - Generated C++ pw_protobuf code +* ``${target_name}.pwpb_rpc`` - Generated C++ pw_protobuf pw_rpc code * ``${target_name}.nanopb`` - Generated C++ nanopb code (requires Nanopb) * ``${target_name}.nanopb_rpc`` - Generated C++ Nanopb pw_rpc code (requires Nanopb) @@ -330,6 +334,7 @@ CMake supports the following compiled proto libraries via the specified sub-targets generated by a ``pw_proto_library``. * ``${NAME}.pwpb`` - Generated C++ pw_protobuf code +* ``${NAME}.pwpb_rpc`` - Generated C++ pw_protobuf pw_rpc code * ``${NAME}.nanopb`` - Generated C++ nanopb code (requires Nanopb) * ``${NAME}.nanopb_rpc`` - Generated C++ Nanopb pw_rpc code (requires Nanopb) * ``${NAME}.raw_rpc`` - Generated C++ raw pw_rpc code (no protobuf library) @@ -416,7 +421,7 @@ compile them. e.g. name = "my_lib", srcs = ["my/lib.cc"], # This target depends on all generated proto targets - # e.g. name.{pwpb, nanopb, raw_rpc, nanopb_rpc} + # e.g. name.{pwpb, pwpb_rpc, nanopb, raw_rpc, nanopb_rpc} deps = [":my_cc_proto"], ) @@ -440,6 +445,7 @@ Bazel supports the following compiled proto libraries via the specified sub-targets generated by a ``pw_proto_library``. * ``${NAME}.pwpb`` - Generated C++ pw_protobuf code +* ``${NAME}.pwpb_rpc`` - Generated C++ pw_protobuf pw_rpc code * ``${NAME}.nanopb`` - Generated C++ nanopb code * ``${NAME}.raw_rpc`` - Generated C++ raw pw_rpc code (no protobuf library) * ``${NAME}.nanopb_rpc`` - Generated C++ Nanopb pw_rpc code diff --git a/pw_protobuf_compiler/proto.cmake b/pw_protobuf_compiler/proto.cmake index dbe4fe2b66..305e678992 100644 --- a/pw_protobuf_compiler/proto.cmake +++ b/pw_protobuf_compiler/proto.cmake @@ -21,9 +21,9 @@ include_guard(GLOBAL) # # This function also creates libraries for generating pw_rpc code: # +# ${NAME}.pwpb_rpc - generates pw_protobuf pw_rpc code # ${NAME}.nanopb_rpc - generates Nanopb pw_rpc code # ${NAME}.raw_rpc - generates raw pw_rpc (no protobuf library) code -# ${NAME}.pwpb_rpc - (Not implemented) generates pw_protobuf pw_rpc code # # Args: # @@ -105,6 +105,8 @@ function(pw_proto_library NAME) # Create a protobuf target for each supported protobuf library. _pw_pwpb_library( "${NAME}" "${sources}" "${inputs}" "${arg_DEPS}" "${include_file}" "${out_dir}") + _pw_pwpb_rpc_library( + "${NAME}" "${sources}" "${inputs}" "${arg_DEPS}" "${include_file}" "${out_dir}") _pw_raw_rpc_library( "${NAME}" "${sources}" "${inputs}" "${arg_DEPS}" "${include_file}" "${out_dir}") _pw_nanopb_library( @@ -200,6 +202,40 @@ function(_pw_pwpb_library NAME SOURCES INPUTS DEPS INCLUDE_FILE OUT_DIR) add_dependencies("${NAME}.pwpb" "${NAME}._generate.pwpb") endfunction(_pw_pwpb_library) +# Internal function that creates a pwpb_rpc library. +function(_pw_pwpb_rpc_library NAME SOURCES INPUTS DEPS INCLUDE_FILE OUT_DIR) + # Determine the names of the output files. + list(TRANSFORM DEPS APPEND .pwpb_rpc) + + _pw_generate_protos("${NAME}" + pwpb_rpc + "$ENV{PW_ROOT}/pw_rpc/py/pw_rpc/plugin_pwpb.py" + ".rpc.pwpb.h" + "${INCLUDE_FILE}" + "${OUT_DIR}" + "${SOURCES}" + "${INPUTS}" + "${DEPS}" + ) + + # Create the library with the generated source files. + add_library("${NAME}.pwpb_rpc" INTERFACE) + target_include_directories("${NAME}.pwpb_rpc" + INTERFACE + "${OUT_DIR}/pwpb_rpc" + ) + target_link_libraries("${NAME}.pwpb_rpc" + INTERFACE + "${NAME}.pwpb" + pw_build + pw_rpc.pwpb.client + pw_rpc.pwpb.method_union + pw_rpc.server + ${DEPS} + ) + add_dependencies("${NAME}.pwpb_rpc" "${NAME}._generate.pwpb_rpc") +endfunction(_pw_pwpb_rpc_library) + # Internal function that creates a raw_rpc proto library. function(_pw_raw_rpc_library NAME SOURCES INPUTS DEPS INCLUDE_FILE OUT_DIR) list(TRANSFORM DEPS APPEND .raw_rpc) diff --git a/pw_protobuf_compiler/proto.gni b/pw_protobuf_compiler/proto.gni index 63a8c61d7c..9499a07bd0 100644 --- a/pw_protobuf_compiler/proto.gni +++ b/pw_protobuf_compiler/proto.gni @@ -136,6 +136,38 @@ template("_pw_invoke_protoc") { # Generates pw_protobuf C++ code for proto files, creating a source_set of the # generated files. This is internal and should not be used outside of this file. # Use pw_proto_library instead. +template("_pw_pwpb_rpc_proto_library") { + # Create a target which runs protoc configured with the pwpb_rpc plugin to + # generate the C++ proto RPC headers. + _pw_invoke_protoc(target_name) { + forward_variables_from(invoker, "*", _forwarded_vars) + language = "pwpb_rpc" + plugin = "$dir_pw_rpc/py/pw_rpc/plugin_pwpb.py" + python_deps = [ "$dir_pw_rpc/py" ] + } + + # Create a library with the generated source files. + config("$target_name._include_path") { + include_dirs = [ "${invoker.base_out_dir}/pwpb_rpc" ] + visibility = [ ":*" ] + } + + pw_source_set(target_name) { + forward_variables_from(invoker, _forwarded_vars) + public_configs = [ ":$target_name._include_path" ] + deps = [ ":$target_name._gen($pw_protobuf_compiler_TOOLCHAIN)" ] + public_deps = [ + ":${invoker.base_target}.pwpb", + "$dir_pw_protobuf", + "$dir_pw_rpc:server", + "$dir_pw_rpc/pwpb:client_api", + "$dir_pw_rpc/pwpb:server_api", + ] + invoker.deps + public = invoker.outputs + check_includes = false + } +} + template("_pw_pwpb_proto_library") { _pw_invoke_protoc(target_name) { forward_variables_from(invoker, "*", _forwarded_vars) @@ -532,6 +564,22 @@ template("pw_proto_library") { # Enumerate all of the protobuf generator targets. + _pw_pwpb_rpc_proto_library("$target_name.pwpb_rpc") { + forward_variables_from(invoker, _forwarded_vars) + forward_variables_from(_common, "*") + + deps = [] + foreach(dep, _deps) { + _base = get_label_info(dep, "label_no_toolchain") + deps += [ "$_base.pwpb_rpc(" + get_label_info(dep, "toolchain") + ")" ] + } + + outputs = [] + foreach(name, _source_names) { + outputs += [ "$base_out_dir/pwpb_rpc/$_prefix/${name}.rpc.pwpb.h" ] + } + } + _pw_pwpb_proto_library("$target_name.pwpb") { forward_variables_from(invoker, _forwarded_vars) forward_variables_from(_common, "*") @@ -652,6 +700,7 @@ template("pw_proto_library") { # All supported pw_protobuf generators. _protobuf_generators = [ "pwpb", + "pwpb_rpc", "nanopb", "nanopb_rpc", "raw_rpc", diff --git a/pw_protobuf_compiler/pw_proto_library.bzl b/pw_protobuf_compiler/pw_proto_library.bzl index 4ca40ceec9..1c0deeb89d 100644 --- a/pw_protobuf_compiler/pw_proto_library.bzl +++ b/pw_protobuf_compiler/pw_proto_library.bzl @@ -88,6 +88,8 @@ def pw_proto_library(name = "", deps = [], nanopb_options = None): The pw_proto_library generates the following targets in this example: "benchmark_pw_proto.pwpb": C++ library exposing the "benchmark.pwpb.h" header. + "benchmark_pw_proto.pwpb_rpc": C++ library exposing the + "benchmark.rpc.pwpb.h" header. "benchmark_pw_proto.raw_rpc": C++ library exposing the "benchmark.raw_rpc.h" header. "benchmark_pw_proto.nanopb": C++ library exposing the "benchmark.pb.h" @@ -108,9 +110,11 @@ def pw_proto_library(name = "", deps = [], nanopb_options = None): deps = deps, ) - # The rpc.pb.h header depends on the generated nanopb code. + # The rpc.pb.h header depends on the generated nanopb or pwpb code. if info["include_nanopb_dep"]: lib_deps = info["deps"] + [":" + name + ".nanopb"] + elif info["include_pwpb_dep"]: + lib_deps = info["deps"] + [":" + name + ".pwpb"] else: lib_deps = info["deps"] @@ -235,6 +239,18 @@ _pw_proto_library = rule( }, ) +_pw_pwpb_rpc_proto_compiler_aspect = _proto_compiler_aspect("rpc.pwpb.h", "//pw_rpc/py:plugin_pwpb") + +_pw_pwpb_rpc_proto_library = rule( + implementation = _impl_pw_proto_library, + attrs = { + "deps": attr.label_list( + providers = [ProtoInfo], + aspects = [_pw_pwpb_rpc_proto_compiler_aspect], + ), + }, +) + _pw_raw_rpc_proto_compiler_aspect = _proto_compiler_aspect("raw_rpc.pb.h", "//pw_rpc/py:plugin_raw") _pw_raw_rpc_proto_library = rule( @@ -267,6 +283,18 @@ PIGWEED_PLUGIN = { "//pw_protobuf:pw_protobuf", ], "include_nanopb_dep": False, + "include_pwpb_dep": False, + }, + "pwpb_rpc": { + "compiler": _pw_pwpb_rpc_proto_library, + "deps": [ + "//pw_protobuf:pw_protobuf", + "//pw_rpc", + "//pw_rpc/raw:client_api", + "//pw_rpc/raw:server_api", + ], + "include_nanopb_dep": False, + "include_pwpb_dep": True, }, "raw_rpc": { "compiler": _pw_raw_rpc_proto_library, @@ -276,6 +304,7 @@ PIGWEED_PLUGIN = { "//pw_rpc/raw:server_api", ], "include_nanopb_dep": False, + "include_pwpb_dep": False, }, "nanopb_rpc": { "compiler": _pw_nanopb_rpc_proto_library, @@ -285,5 +314,6 @@ PIGWEED_PLUGIN = { "//pw_rpc/nanopb:server_api", ], "include_nanopb_dep": True, + "include_pwpb_dep": False, }, } diff --git a/pw_protobuf_compiler/pw_protobuf_compiler_pwpb_protos/pwpb_test.options b/pw_protobuf_compiler/pw_protobuf_compiler_pwpb_protos/pwpb_test.options new file mode 100644 index 0000000000..7550d85279 --- /dev/null +++ b/pw_protobuf_compiler/pw_protobuf_compiler_pwpb_protos/pwpb_test.options @@ -0,0 +1,15 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +pw.protobuf_compiler.Point.name max_size:16 diff --git a/pw_protobuf_compiler/pw_protobuf_compiler_pwpb_protos/pwpb_test.proto b/pw_protobuf_compiler/pw_protobuf_compiler_pwpb_protos/pwpb_test.proto new file mode 100644 index 0000000000..5faf7c1e0f --- /dev/null +++ b/pw_protobuf_compiler/pw_protobuf_compiler_pwpb_protos/pwpb_test.proto @@ -0,0 +1,22 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. +syntax = "proto3"; + +package pw.protobuf_compiler; + +message Point { + uint32 x = 1; + uint32 y = 2; + string name = 3; +}; diff --git a/pw_protobuf_compiler/pwpb_test.cc b/pw_protobuf_compiler/pwpb_test.cc new file mode 100644 index 0000000000..10c3962c42 --- /dev/null +++ b/pw_protobuf_compiler/pwpb_test.cc @@ -0,0 +1,24 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include "gtest/gtest.h" +#include "pw_protobuf_compiler_pwpb_protos/pwpb_test.pwpb.h" + +TEST(Pwpb, CompilesProtobufs) { + pw::protobuf_compiler::Point::Message point = {4, 8, "point"}; + EXPECT_EQ(point.x, 4u); + EXPECT_EQ(point.y, 8u); + EXPECT_EQ(point.name.size(), 5u); + EXPECT_EQ(point.name.view(), "point"); +} diff --git a/pw_protobuf_compiler/py/pw_protobuf_compiler/generate_protos.py b/pw_protobuf_compiler/py/pw_protobuf_compiler/generate_protos.py index f7a69360f0..07a49a046f 100644 --- a/pw_protobuf_compiler/py/pw_protobuf_compiler/generate_protos.py +++ b/pw_protobuf_compiler/py/pw_protobuf_compiler/generate_protos.py @@ -65,7 +65,7 @@ def _argument_parser() -> argparse.ArgumentParser: return parser -def protoc_cc_args(args: argparse.Namespace) -> Tuple[str, ...]: +def protoc_pwpb_args(args: argparse.Namespace) -> Tuple[str, ...]: return _COMMON_FLAGS + ( '--plugin', f'protoc-gen-custom={args.plugin_path}', @@ -75,6 +75,15 @@ def protoc_cc_args(args: argparse.Namespace) -> Tuple[str, ...]: ) +def protoc_pwpb_rpc_args(args: argparse.Namespace) -> Tuple[str, ...]: + return _COMMON_FLAGS + ( + '--plugin', + f'protoc-gen-custom={args.plugin_path}', + '--custom_out', + args.out_dir, + ) + + def protoc_go_args(args: argparse.Namespace) -> Tuple[str, ...]: return _COMMON_FLAGS + ( '--go_out', @@ -128,12 +137,13 @@ def protoc_python_args(args: argparse.Namespace) -> Tuple[str, ...]: # Default additional protoc arguments for each supported language. # TODO(frolv): Make these overridable with a command-line argument. DEFAULT_PROTOC_ARGS: Dict[str, _DefaultArgsFunction] = { - 'pwpb': protoc_cc_args, 'go': protoc_go_args, 'nanopb': protoc_nanopb_args, 'nanopb_rpc': protoc_nanopb_rpc_args, - 'raw_rpc': protoc_raw_rpc_args, + 'pwpb': protoc_pwpb_args, + 'pwpb_rpc': protoc_pwpb_rpc_args, 'python': protoc_python_args, + 'raw_rpc': protoc_raw_rpc_args, } # Languages that protoc internally supports. diff --git a/pw_rpc/BUILD.bazel b/pw_rpc/BUILD.bazel index 8a9aba995f..402da1fb21 100644 --- a/pw_rpc/BUILD.bazel +++ b/pw_rpc/BUILD.bazel @@ -335,6 +335,17 @@ proto_plugin( visibility = ["//visibility:public"], ) +proto_plugin( + name = "pw_cc_plugin_pwpb_rpc", + outputs = [ + "{protopath}.rpc.pwpb.h", + ], + protoc_plugin_name = "pwpb_rpc", + tool = "@pigweed//pw_rpc/py:plugin_pwpb", + use_built_in_shell_environment = True, + visibility = ["//visibility:public"], +) + proto_library( name = "echo_proto", srcs = [ diff --git a/pw_rpc/BUILD.gn b/pw_rpc/BUILD.gn index 572ea95bc9..3dc0d1c57c 100644 --- a/pw_rpc/BUILD.gn +++ b/pw_rpc/BUILD.gn @@ -261,6 +261,8 @@ pw_executable("client_integration_test") { dir_pw_unit_test, ] + deps += [ "pwpb:client_integration_test" ] + if (dir_pw_third_party_nanopb != "") { deps += [ "nanopb:client_integration_test" ] } @@ -321,6 +323,7 @@ pw_doc_group("docs") { ] group_deps = [ "nanopb:docs", + "pwpb:docs", "py:docs", "ts:docs", ] @@ -363,6 +366,7 @@ pw_test_group("tests") { ] group_deps = [ "nanopb:tests", + "pwpb:tests", "raw:tests", ] } diff --git a/pw_rpc/CMakeLists.txt b/pw_rpc/CMakeLists.txt index 7d5b2cf475..022f2827b9 100644 --- a/pw_rpc/CMakeLists.txt +++ b/pw_rpc/CMakeLists.txt @@ -19,6 +19,7 @@ if(NOT "${dir_pw_third_party_nanopb}" STREQUAL "") add_subdirectory(nanopb) endif() +add_subdirectory(pwpb) add_subdirectory(raw) add_subdirectory(system_server) diff --git a/pw_rpc/docs.rst b/pw_rpc/docs.rst index 25554f05a1..71bc1c8598 100644 --- a/pw_rpc/docs.rst +++ b/pw_rpc/docs.rst @@ -43,8 +43,8 @@ Pigweed provides several client and server implementations of ``pw_rpc``. - ✅ - ✅ * - C++ (pw_protobuf) - - planned - - planned + - ✅ + - ✅ * - Java - - in development @@ -166,10 +166,11 @@ This protocol buffer is declared in a ``BUILD.gn`` file as follows: If you need to distinguish between a default-valued field and a missing field, mark the field as ``optional``. The presence of the field can be detected - with a ``HasField(name)`` or ``has_`` member, depending on the library. + with ``std::optional``, a ``HasField(name)``, or ``has_`` member, + depending on the library. - Optional fields have some overhead --- default-valued fields are included in - the encoded proto, and, if using Nanopb, the proto structs have a + Optional fields have some overhead --- if using Nanopb, default-valued fields + are included in the encoded proto, and the proto structs have a ``has_`` flag for each optional field. Use plain fields if field presence detection is not needed. @@ -207,9 +208,9 @@ For example, the generated RPC header for ``"foo_bar/the_service.proto"`` is The generated header defines a base class for each RPC service declared in the ``.proto`` file. A service named ``TheService`` in package ``foo.bar`` would -generate the following base class for Nanopb: +generate the following base class for pw_protobuf: -.. cpp:class:: template foo::bar::pw_rpc::nanopb::TheService::Service +.. cpp:class:: template foo::bar::pw_rpc::pwpb::TheService::Service 3. RPC service definition ------------------------- @@ -230,7 +231,7 @@ Services may mix and match protobuf implementations within one service. .. code-block:: sh - find out/ -name .rpc.pb.h + find out/ -name .rpc.pwpb.h #. Scroll to the bottom of the generated RPC header. #. Copy the stub class declaration to a header file. @@ -239,32 +240,33 @@ Services may mix and match protobuf implementations within one service. #. List these files in a build target with a dependency on the ``pw_proto_library``. -A Nanopb implementation of this service would be as follows: +A pw_protobuf implementation of this service would be as follows: .. code-block:: cpp - #include "foo_bar/the_service.rpc.pb.h" + #include "foo_bar/the_service.rpc.pwpb.h" namespace foo::bar { - class TheService : public pw_rpc::nanopb::TheService::Service { + class TheService : public pw_rpc::pwpb::TheService::Service { public: - pw::Status MethodOne(const foo_bar_Request& request, - foo_bar_Response& response) { + pw::Status MethodOne(const Request::Message& request, + Response::Message& response) { // implementation + response.number = 123; return pw::OkStatus(); } - void MethodTwo(const foo_bar_Request& request, - ServerWriter& response) { + void MethodTwo(const Request::Message& request, + ServerWriter& response) { // implementation - response.Write(foo_bar_Response{.number = 123}); + response.Write({.number = 123}); } }; } // namespace foo::bar -The Nanopb implementation would be declared in a ``BUILD.gn``: +The pw_protobuf implementation would be declared in a ``BUILD.gn``: .. code-block:: python @@ -275,14 +277,9 @@ The Nanopb implementation would be declared in a ``BUILD.gn``: pw_source_set("the_service") { public_configs = [ ":public" ] public = [ "public/foo_bar/service.h" ] - public_deps = [ ":the_service_proto.nanopb_rpc" ] + public_deps = [ ":the_service_proto.pwpb_rpc" ] } -.. attention:: - - pw_rpc's generated classes will support using ``pw_protobuf`` or raw buffers - (no protobuf library) in the future. - 4. Register the service with a server ------------------------------------- This example code sets up an RPC server with an :ref:`HDLC` @@ -401,6 +398,7 @@ Protobuf library APIs .. toctree:: :maxdepth: 1 + pwpb/docs nanopb/docs Testing a pw_rpc integration @@ -413,14 +411,14 @@ working as intended by registering the provided ``EchoService``, defined in :language: protobuf :lines: 14- -For example, in C++ with nanopb: +For example, in C++ with pw_protobuf: .. code:: c++ #include "pw_rpc/server.h" // Include the apporpriate header for your protobuf library. - #include "pw_rpc/echo_service_nanopb.h" + #include "pw_rpc/echo_service_pwpb.h" constexpr pw::rpc::Channel kChannels[] = { /* ... */ }; static pw::rpc::Server server(kChannels); @@ -836,7 +834,7 @@ Example ^^^^^^^ .. code-block:: c++ - #include "pw_rpc/echo_service_nanopb.h" + #include "pw_rpc/echo_service_pwpb.h" namespace { // Generated clients are namespaced with their proto library. @@ -849,7 +847,7 @@ Example // Callback invoked when a response is received. This is called synchronously // from Client::ProcessPacket. - void EchoResponse(const pw_rpc_EchoMessage& response, + void EchoResponse(const EchoMessage::Message& response, pw::Status status) { if (status.ok()) { PW_LOG_INFO("Received echo response: %s", response.msg); @@ -865,7 +863,7 @@ Example // Create a client to call the EchoService. EchoClient echo_client(my_rpc_client, kDefaultChannelId); - pw_rpc_EchoMessage request = pw_rpc_EchoMessage_init_default; + EchoMessage::Message request{}; pw::string::Copy(message, request.msg); // By assigning the returned ClientCall to the global echo_call, the RPC @@ -927,11 +925,13 @@ what packets are sent by an RPC client in tests. Both raw and Nanopb interfaces are supported. Code that uses the raw API may be tested with the Nanopb test helpers, and vice versa. -To test code that invokes RPCs, declare a ``RawClientTestContext`` or -``NanopbClientTestContext``. These test context objects provide a -preconfigured RPC client, channel, server fake, and buffer for encoding packets. -These test classes are defined in ``pw_rpc/raw/client_testing.h`` and -``pw_rpc/nanopb/client_testing.h``. +To test code that invokes RPCs, declare a ``RawClientTestContext``, +``PwpbClientTestContext``, or ``NanopbClientTestContext``. These test context +objects provide a preconfigured RPC client, channel, server fake, and buffer for +encoding packets. + +These test classes are defined in ``pw_rpc/raw/client_testing.h``, +``pw_rpc/pwpb/client_testing.h``, or ``pw_rpc/nanopb/client_testing.h``. Use the context's ``client()`` and ``channel()`` to invoke RPCs. Use the context's ``server()`` to simulate responses. To verify that the client sent the diff --git a/pw_rpc/public/pw_rpc/internal/method_lookup.h b/pw_rpc/public/pw_rpc/internal/method_lookup.h index be0745582e..232e0c4dbd 100644 --- a/pw_rpc/public/pw_rpc/internal/method_lookup.h +++ b/pw_rpc/public/pw_rpc/internal/method_lookup.h @@ -34,6 +34,13 @@ class MethodLookup { return method; } + template + static constexpr const auto& GetPwpbMethod() { + const auto& method = GetMethodUnion().pwpb_method(); + static_assert(method.id() == kMethodId, "Incorrect method implementation"); + return method; + } + template static constexpr const auto& GetNanopbMethod() { const auto& method = GetMethodUnion().nanopb_method(); diff --git a/pw_rpc/public/pw_rpc/payloads_view.h b/pw_rpc/public/pw_rpc/payloads_view.h index 3d4934f883..07fa7f3fbf 100644 --- a/pw_rpc/public/pw_rpc/payloads_view.h +++ b/pw_rpc/public/pw_rpc/payloads_view.h @@ -124,6 +124,9 @@ class PayloadsView { template friend class NanopbPayloadsView; + template + friend class PwpbPayloadsView; + template using MethodInfo = internal::MethodInfo; diff --git a/pw_rpc/public/pw_rpc/server.h b/pw_rpc/public/pw_rpc/server.h index 69eae07fcc..949737a8e4 100644 --- a/pw_rpc/public/pw_rpc/server.h +++ b/pw_rpc/public/pw_rpc/server.h @@ -92,6 +92,15 @@ class Server : public internal::Endpoint { template friend class NanopbUnaryResponder; + template + friend class PwpbServerReaderWriter; + template + friend class PwpbServerWriter; + template + friend class PwpbServerReader; + template + friend class PwpbUnaryResponder; + // Creates a call context for a particular RPC. Unlike the CallContext // constructor, this function checks the type of RPC at compile time. template + +#include "gtest/gtest.h" +#include "pw_rpc/internal/test_utils.h" +#include "pw_rpc/pwpb/client_reader_writer.h" +#include "pw_rpc_pwpb_private/internal_test_utils.h" +#include "pw_rpc_test_protos/test.pwpb.h" + +PW_MODIFY_DIAGNOSTICS_PUSH(); +PW_MODIFY_DIAGNOSTIC(ignored, "-Wmissing-field-initializers"); + +namespace pw::rpc { +namespace { + +using internal::ClientContextForTest; + +constexpr uint32_t kServiceId = 16; +constexpr uint32_t kUnaryMethodId = 111; +constexpr uint32_t kServerStreamingMethodId = 112; + +class FakeGeneratedServiceClient { + public: + static PwpbUnaryReceiver TestUnaryRpc( + Client& client, + uint32_t channel_id, + const test::TestRequest::Message& request, + Function on_response, + Function on_error = nullptr) { + return internal::PwpbUnaryResponseClientCall:: + Start>( + client, + channel_id, + kServiceId, + kUnaryMethodId, + internal::kPwpbMethodSerde<&test::TestRequest::kMessageFields, + &test::TestResponse::kMessageFields>, + std::move(on_response), + std::move(on_error), + request); + } + + static PwpbUnaryReceiver TestAnotherUnaryRpc( + Client& client, + uint32_t channel_id, + const test::TestRequest::Message& request, + Function on_response, + Function on_error = nullptr) { + return internal::PwpbUnaryResponseClientCall:: + Start>( + client, + channel_id, + kServiceId, + kUnaryMethodId, + internal::kPwpbMethodSerde<&test::TestRequest::kMessageFields, + &test::TestResponse::kMessageFields>, + std::move(on_response), + std::move(on_error), + request); + } + + static PwpbClientReader + TestServerStreamRpc( + Client& client, + uint32_t channel_id, + const test::TestRequest::Message& request, + Function on_response, + Function on_stream_end, + Function on_error = nullptr) { + return internal:: + PwpbStreamResponseClientCall::Start< + PwpbClientReader>( + client, + channel_id, + kServiceId, + kServerStreamingMethodId, + internal::kPwpbMethodSerde< + &test::TestRequest::kMessageFields, + &test::TestStreamResponse::kMessageFields>, + std::move(on_response), + std::move(on_stream_end), + std::move(on_error), + request); + } +}; + +TEST(PwpbClientCall, Unary_SendsRequestPacket) { + ClientContextForTest context; + + auto call = FakeGeneratedServiceClient::TestUnaryRpc( + context.client(), + context.channel().id(), + {.integer = 123, .status_code = 0}, + nullptr); + + EXPECT_EQ(context.output().total_packets(), 1u); + auto packet = context.output().last_packet(); + EXPECT_EQ(packet.channel_id(), context.channel().id()); + EXPECT_EQ(packet.service_id(), kServiceId); + EXPECT_EQ(packet.method_id(), kUnaryMethodId); + + PW_DECODE_PB(test::TestRequest, sent_proto, packet.payload()); + EXPECT_EQ(sent_proto.integer, 123); +} + +class UnaryClientCall : public ::testing::Test { + protected: + std::optional last_status_; + std::optional last_error_; + int responses_received_ = 0; + int last_response_value_ = 0; +}; + +TEST_F(UnaryClientCall, InvokesCallbackOnValidResponse) { + ClientContextForTest context; + + auto call = FakeGeneratedServiceClient::TestUnaryRpc( + context.client(), + context.channel().id(), + {.integer = 123, .status_code = 0}, + [this](const test::TestResponse::Message& response, Status status) { + ++responses_received_; + last_status_ = status; + last_response_value_ = response.value; + }); + + PW_ENCODE_PB(test::TestResponse, response, .value = 42); + EXPECT_EQ(OkStatus(), context.SendResponse(OkStatus(), response)); + + ASSERT_EQ(responses_received_, 1); + EXPECT_EQ(last_status_, OkStatus()); + EXPECT_EQ(last_response_value_, 42); +} + +TEST_F(UnaryClientCall, DoesNothingOnNullCallback) { + ClientContextForTest context; + + auto call = FakeGeneratedServiceClient::TestUnaryRpc( + context.client(), + context.channel().id(), + {.integer = 123, .status_code = 0}, + nullptr); + + PW_ENCODE_PB(test::TestResponse, response, .value = 42); + EXPECT_EQ(OkStatus(), context.SendResponse(OkStatus(), response)); + + ASSERT_EQ(responses_received_, 0); +} + +TEST_F(UnaryClientCall, InvokesErrorCallbackOnInvalidResponse) { + ClientContextForTest context; + + auto call = FakeGeneratedServiceClient::TestUnaryRpc( + context.client(), + context.channel().id(), + {.integer = 123, .status_code = 0}, + [this](const test::TestResponse::Message& response, Status status) { + ++responses_received_; + last_status_ = status; + last_response_value_ = response.value; + }, + [this](Status status) { last_error_ = status; }); + + constexpr std::byte bad_payload[]{ + std::byte{0xab}, std::byte{0xcd}, std::byte{0xef}}; + EXPECT_EQ(OkStatus(), context.SendResponse(OkStatus(), bad_payload)); + + EXPECT_EQ(responses_received_, 0); + ASSERT_TRUE(last_error_.has_value()); + EXPECT_EQ(last_error_, Status::DataLoss()); +} + +TEST_F(UnaryClientCall, InvokesErrorCallbackOnServerError) { + ClientContextForTest context; + + auto call = FakeGeneratedServiceClient::TestUnaryRpc( + context.client(), + context.channel().id(), + {.integer = 123, .status_code = 0}, + [this](const test::TestResponse::Message& response, Status status) { + ++responses_received_; + last_status_ = status; + last_response_value_ = response.value; + }, + [this](Status status) { last_error_ = status; }); + + EXPECT_EQ(OkStatus(), + context.SendPacket(internal::PacketType::SERVER_ERROR, + Status::NotFound())); + + EXPECT_EQ(responses_received_, 0); + EXPECT_EQ(last_error_, Status::NotFound()); +} + +TEST_F(UnaryClientCall, DoesNothingOnErrorWithoutCallback) { + ClientContextForTest context; + + auto call = FakeGeneratedServiceClient::TestUnaryRpc( + context.client(), + context.channel().id(), + {.integer = 123, .status_code = 0}, + [this](const test::TestResponse::Message& response, Status status) { + ++responses_received_; + last_status_ = status; + last_response_value_ = response.value; + }); + + constexpr std::byte bad_payload[]{ + std::byte{0xab}, std::byte{0xcd}, std::byte{0xef}}; + EXPECT_EQ(OkStatus(), context.SendResponse(OkStatus(), bad_payload)); + + EXPECT_EQ(responses_received_, 0); +} + +TEST_F(UnaryClientCall, OnlyReceivesOneResponse) { + ClientContextForTest context; + + auto call = FakeGeneratedServiceClient::TestUnaryRpc( + context.client(), + context.channel().id(), + {.integer = 123, .status_code = 0}, + [this](const test::TestResponse::Message& response, Status status) { + ++responses_received_; + last_status_ = status; + last_response_value_ = response.value; + }); + + PW_ENCODE_PB(test::TestResponse, r1, .value = 42); + EXPECT_EQ(OkStatus(), context.SendResponse(Status::Unimplemented(), r1)); + PW_ENCODE_PB(test::TestResponse, r2, .value = 44); + EXPECT_EQ(OkStatus(), context.SendResponse(Status::OutOfRange(), r2)); + PW_ENCODE_PB(test::TestResponse, r3, .value = 46); + EXPECT_EQ(OkStatus(), context.SendResponse(Status::Internal(), r3)); + + EXPECT_EQ(responses_received_, 1); + EXPECT_EQ(last_status_, Status::Unimplemented()); + EXPECT_EQ(last_response_value_, 42); +} + +class ServerStreamingClientCall : public ::testing::Test { + protected: + bool active_ = true; + std::optional stream_status_; + std::optional rpc_error_; + int responses_received_ = 0; + int last_response_number_ = 0; +}; + +TEST_F(ServerStreamingClientCall, SendsRequestPacket) { + ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId> context; + + auto call = FakeGeneratedServiceClient::TestServerStreamRpc( + context.client(), + context.channel().id(), + {.integer = 71, .status_code = 0}, + nullptr, + nullptr); + + EXPECT_EQ(context.output().total_packets(), 1u); + auto packet = context.output().last_packet(); + EXPECT_EQ(packet.channel_id(), context.channel().id()); + EXPECT_EQ(packet.service_id(), kServiceId); + EXPECT_EQ(packet.method_id(), kServerStreamingMethodId); + + PW_DECODE_PB(test::TestRequest, sent_proto, packet.payload()); + EXPECT_EQ(sent_proto.integer, 71); +} + +TEST_F(ServerStreamingClientCall, InvokesCallbackOnValidResponse) { + ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId> context; + + auto call = FakeGeneratedServiceClient::TestServerStreamRpc( + context.client(), + context.channel().id(), + {.integer = 71, .status_code = 0}, + [this](const test::TestStreamResponse::Message& response) { + ++responses_received_; + last_response_number_ = response.number; + }, + [this](Status status) { + active_ = false; + stream_status_ = status; + }); + + PW_ENCODE_PB(test::TestStreamResponse, r1, .chunk = {}, .number = 11u); + EXPECT_EQ(OkStatus(), context.SendServerStream(r1)); + EXPECT_TRUE(active_); + EXPECT_EQ(responses_received_, 1); + EXPECT_EQ(last_response_number_, 11); + + PW_ENCODE_PB(test::TestStreamResponse, r2, .chunk = {}, .number = 22u); + EXPECT_EQ(OkStatus(), context.SendServerStream(r2)); + EXPECT_TRUE(active_); + EXPECT_EQ(responses_received_, 2); + EXPECT_EQ(last_response_number_, 22); + + PW_ENCODE_PB(test::TestStreamResponse, r3, .chunk = {}, .number = 33u); + EXPECT_EQ(OkStatus(), context.SendServerStream(r3)); + EXPECT_TRUE(active_); + EXPECT_EQ(responses_received_, 3); + EXPECT_EQ(last_response_number_, 33); +} + +TEST_F(ServerStreamingClientCall, InvokesStreamEndOnFinish) { + ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId> context; + + auto call = FakeGeneratedServiceClient::TestServerStreamRpc( + context.client(), + context.channel().id(), + {.integer = 71, .status_code = 0}, + [this](const test::TestStreamResponse::Message& response) { + ++responses_received_; + last_response_number_ = response.number; + }, + [this](Status status) { + active_ = false; + stream_status_ = status; + }); + + PW_ENCODE_PB(test::TestStreamResponse, r1, .chunk = {}, .number = 11u); + EXPECT_EQ(OkStatus(), context.SendServerStream(r1)); + EXPECT_TRUE(active_); + + PW_ENCODE_PB(test::TestStreamResponse, r2, .chunk = {}, .number = 22u); + EXPECT_EQ(OkStatus(), context.SendServerStream(r2)); + EXPECT_TRUE(active_); + + // Close the stream. + EXPECT_EQ(OkStatus(), context.SendResponse(Status::NotFound())); + + PW_ENCODE_PB(test::TestStreamResponse, r3, .chunk = {}, .number = 33u); + EXPECT_EQ(OkStatus(), context.SendServerStream(r3)); + EXPECT_FALSE(active_); + + EXPECT_EQ(responses_received_, 2); +} + +TEST_F(ServerStreamingClientCall, InvokesErrorCallbackOnInvalidResponses) { + ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId> context; + + auto call = FakeGeneratedServiceClient::TestServerStreamRpc( + context.client(), + context.channel().id(), + {.integer = 71, .status_code = 0}, + [this](const test::TestStreamResponse::Message& response) { + ++responses_received_; + last_response_number_ = response.number; + }, + nullptr, + [this](Status error) { rpc_error_ = error; }); + + PW_ENCODE_PB(test::TestStreamResponse, r1, .chunk = {}, .number = 11u); + EXPECT_EQ(OkStatus(), context.SendServerStream(r1)); + EXPECT_TRUE(active_); + EXPECT_EQ(responses_received_, 1); + EXPECT_EQ(last_response_number_, 11); + + constexpr std::byte bad_payload[]{ + std::byte{0xab}, std::byte{0xcd}, std::byte{0xef}}; + EXPECT_EQ(OkStatus(), context.SendServerStream(bad_payload)); + EXPECT_EQ(responses_received_, 1); + ASSERT_TRUE(rpc_error_.has_value()); + EXPECT_EQ(rpc_error_, Status::DataLoss()); + + PW_ENCODE_PB(test::TestStreamResponse, r2, .chunk = {}, .number = 22u); + EXPECT_EQ(OkStatus(), context.SendServerStream(r2)); + EXPECT_TRUE(active_); + EXPECT_EQ(responses_received_, 2); + EXPECT_EQ(last_response_number_, 22); + + EXPECT_EQ(OkStatus(), + context.SendPacket(internal::PacketType::SERVER_ERROR, + Status::NotFound())); + EXPECT_EQ(responses_received_, 2); + EXPECT_EQ(rpc_error_, Status::NotFound()); +} + +} // namespace +} // namespace pw::rpc + +PW_MODIFY_DIAGNOSTICS_POP(); diff --git a/pw_rpc/pwpb/client_integration_test.cc b/pw_rpc/pwpb/client_integration_test.cc new file mode 100644 index 0000000000..e0d55bf1f4 --- /dev/null +++ b/pw_rpc/pwpb/client_integration_test.cc @@ -0,0 +1,158 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include "gtest/gtest.h" +#include "pw_assert/check.h" +#include "pw_rpc/benchmark.rpc.pwpb.h" +#include "pw_rpc/integration_testing.h" +#include "pw_sync/binary_semaphore.h" + +namespace pwpb_rpc_test { +namespace { + +using namespace std::chrono_literals; +using pw::ByteSpan; +using pw::ConstByteSpan; +using pw::Function; +using pw::OkStatus; +using pw::Status; + +using pw::rpc::pw_rpc::pwpb::Benchmark; + +constexpr int kIterations = 10; + +class PayloadReceiver { + public: + const char* Wait() { + PW_CHECK(sem_.try_acquire_for(1500ms)); + return reinterpret_cast(payload_.payload.data()); + } + + Function UnaryOnCompleted() { + return [this](const pw::rpc::Payload::Message& data, Status) { + CopyPayload(data); + }; + } + + Function OnNext() { + return [this](const pw::rpc::Payload::Message& data) { CopyPayload(data); }; + } + + private: + void CopyPayload(const pw::rpc::Payload::Message& data) { + payload_ = data; + sem_.release(); + } + + pw::sync::BinarySemaphore sem_; + pw::rpc::Payload::Message payload_ = {}; +}; + +template +pw::rpc::Payload::Message Payload(const char (&string)[kSize]) { + static_assert(kSize <= sizeof(pw::rpc::Payload::Message::payload)); + pw::rpc::Payload::Message payload{}; + payload.payload.resize(kSize); + std::memcpy(payload.payload.data(), string, kSize); + return payload; +} + +const Benchmark::Client kClient(pw::rpc::integration_test::client(), + pw::rpc::integration_test::kChannelId); + +TEST(PwpbRpcIntegrationTest, Unary) { + char value[] = {"hello, world!"}; + + for (int i = 0; i < kIterations; ++i) { + PayloadReceiver receiver; + + value[0] = static_cast(i); + pw::rpc::PwpbUnaryReceiver call = + kClient.UnaryEcho(Payload(value), receiver.UnaryOnCompleted()); + ASSERT_STREQ(receiver.Wait(), value); + } +} + +TEST(PwpbRpcIntegrationTest, Unary_ReuseCall) { + pw::rpc::PwpbUnaryReceiver call; + char value[] = {"O_o "}; + + for (int i = 0; i < kIterations; ++i) { + PayloadReceiver receiver; + + value[sizeof(value) - 2] = static_cast(i); + call = kClient.UnaryEcho(Payload(value), receiver.UnaryOnCompleted()); + ASSERT_STREQ(receiver.Wait(), value); + } +} + +TEST(PwpbRpcIntegrationTest, Unary_DiscardCalls) { + constexpr int iterations = PW_RPC_USE_GLOBAL_MUTEX ? 10000 : 1; + for (int i = 0; i < iterations; ++i) { + kClient.UnaryEcho(Payload("O_o")); + } +} + +TEST(PwpbRpcIntegrationTest, BidirectionalStreaming_MoveCalls) { + for (int i = 0; i < kIterations; ++i) { + PayloadReceiver receiver; + pw::rpc::PwpbClientReaderWriter call = + kClient.BidirectionalEcho(receiver.OnNext()); + + ASSERT_EQ(OkStatus(), call.Write(Payload("Yello"))); + ASSERT_STREQ(receiver.Wait(), "Yello"); + + pw::rpc::PwpbClientReaderWriter + new_call = std::move(call); + + EXPECT_EQ(Status::FailedPrecondition(), call.Write(Payload("Dello"))); + + ASSERT_EQ(OkStatus(), new_call.Write(Payload("Dello"))); + ASSERT_STREQ(receiver.Wait(), "Dello"); + + call = std::move(new_call); + + EXPECT_EQ(Status::FailedPrecondition(), new_call.Write(Payload("Dello"))); + + ASSERT_EQ(OkStatus(), call.Write(Payload("???"))); + ASSERT_STREQ(receiver.Wait(), "???"); + + EXPECT_EQ(OkStatus(), call.Cancel()); + EXPECT_EQ(Status::FailedPrecondition(), new_call.Cancel()); + } +} + +TEST(PwpbRpcIntegrationTest, BidirectionalStreaming_ReuseCall) { + pw::rpc::PwpbClientReaderWriter + call; + + for (int i = 0; i < kIterations; ++i) { + PayloadReceiver receiver; + call = kClient.BidirectionalEcho(receiver.OnNext()); + + ASSERT_EQ(OkStatus(), call.Write(Payload("Yello"))); + ASSERT_STREQ(receiver.Wait(), "Yello"); + + ASSERT_EQ(OkStatus(), call.Write(Payload("Dello"))); + ASSERT_STREQ(receiver.Wait(), "Dello"); + + ASSERT_EQ(OkStatus(), call.Write(Payload("???"))); + ASSERT_STREQ(receiver.Wait(), "???"); + } +} + +} // namespace +} // namespace pwpb_rpc_test diff --git a/pw_rpc/pwpb/client_reader_writer_test.cc b/pw_rpc/pwpb/client_reader_writer_test.cc new file mode 100644 index 0000000000..b4d53a39fe --- /dev/null +++ b/pw_rpc/pwpb/client_reader_writer_test.cc @@ -0,0 +1,244 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include "pw_rpc/pwpb/client_reader_writer.h" + +#include + +#include "gtest/gtest.h" +#include "pw_rpc/pwpb/client_testing.h" +#include "pw_rpc_test_protos/test.rpc.pwpb.h" + +PW_MODIFY_DIAGNOSTICS_PUSH(); +PW_MODIFY_DIAGNOSTIC(ignored, "-Wmissing-field-initializers"); + +namespace pw::rpc { +namespace { + +using test::pw_rpc::pwpb::TestService; + +void FailIfCalled(Status) { FAIL(); } +template +void FailIfOnNextCalled(const T&) { + FAIL(); +} +template +void FailIfOnCompletedCalled(const T&, Status) { + FAIL(); +} + +TEST(PwpbUnaryReceiver, DefaultConstructed) { + PwpbUnaryReceiver call; + + ASSERT_FALSE(call.active()); + EXPECT_EQ(call.channel_id(), Channel::kUnassignedChannelId); + + EXPECT_EQ(Status::FailedPrecondition(), call.Cancel()); + + call.set_on_completed([](const test::TestResponse::Message&, Status) {}); + call.set_on_error([](Status) {}); +} + +TEST(PwpbClientWriter, DefaultConstructed) { + PwpbClientWriter + call; + + ASSERT_FALSE(call.active()); + EXPECT_EQ(call.channel_id(), Channel::kUnassignedChannelId); + + EXPECT_EQ(Status::FailedPrecondition(), call.Write({})); + EXPECT_EQ(Status::FailedPrecondition(), call.Cancel()); + EXPECT_EQ(Status::FailedPrecondition(), call.CloseClientStream()); + + call.set_on_completed( + [](const test::TestStreamResponse::Message&, Status) {}); + call.set_on_error([](Status) {}); +} + +TEST(PwpbClientReader, DefaultConstructed) { + PwpbClientReader call; + + ASSERT_FALSE(call.active()); + EXPECT_EQ(call.channel_id(), Channel::kUnassignedChannelId); + + EXPECT_EQ(Status::FailedPrecondition(), call.Cancel()); + + call.set_on_completed([](Status) {}); + call.set_on_next([](const test::TestStreamResponse::Message&) {}); + call.set_on_error([](Status) {}); +} + +TEST(PwpbClientReaderWriter, DefaultConstructed) { + PwpbClientReaderWriter + call; + + ASSERT_FALSE(call.active()); + EXPECT_EQ(call.channel_id(), Channel::kUnassignedChannelId); + + EXPECT_EQ(Status::FailedPrecondition(), call.Write({})); + EXPECT_EQ(Status::FailedPrecondition(), call.Cancel()); + EXPECT_EQ(Status::FailedPrecondition(), call.CloseClientStream()); + + call.set_on_completed([](Status) {}); + call.set_on_next([](const test::TestStreamResponse::Message&) {}); + call.set_on_error([](Status) {}); +} + +TEST(PwpbUnaryReceiver, Closed) { + PwpbClientTestContext ctx; + PwpbUnaryReceiver call = + TestService::TestUnaryRpc( + ctx.client(), + ctx.channel().id(), + {}, + FailIfOnCompletedCalled, + FailIfCalled); + ASSERT_EQ(OkStatus(), call.Cancel()); + + ASSERT_FALSE(call.active()); + EXPECT_EQ(call.channel_id(), Channel::kUnassignedChannelId); + + EXPECT_EQ(Status::FailedPrecondition(), call.Cancel()); + + call.set_on_completed([](const test::TestResponse::Message&, Status) {}); + call.set_on_error([](Status) {}); +} + +TEST(PwpbClientWriter, Closed) { + PwpbClientTestContext ctx; + PwpbClientWriter + call = TestService::TestClientStreamRpc( + ctx.client(), + ctx.channel().id(), + FailIfOnCompletedCalled, + FailIfCalled); + ASSERT_EQ(OkStatus(), call.Cancel()); + + ASSERT_FALSE(call.active()); + EXPECT_EQ(call.channel_id(), Channel::kUnassignedChannelId); + + EXPECT_EQ(Status::FailedPrecondition(), call.Write({})); + EXPECT_EQ(Status::FailedPrecondition(), call.Cancel()); + EXPECT_EQ(Status::FailedPrecondition(), call.CloseClientStream()); + + call.set_on_completed( + [](const test::TestStreamResponse::Message&, Status) {}); + call.set_on_error([](Status) {}); +} + +TEST(PwpbClientReader, Closed) { + PwpbClientTestContext ctx; + PwpbClientReader call = + TestService::TestServerStreamRpc( + ctx.client(), + ctx.channel().id(), + {}, + FailIfOnNextCalled, + FailIfCalled, + FailIfCalled); + ASSERT_EQ(OkStatus(), call.Cancel()); + + ASSERT_FALSE(call.active()); + EXPECT_EQ(call.channel_id(), Channel::kUnassignedChannelId); + + EXPECT_EQ(Status::FailedPrecondition(), call.Cancel()); + + call.set_on_completed([](Status) {}); + call.set_on_next([](const test::TestStreamResponse::Message&) {}); + call.set_on_error([](Status) {}); +} + +TEST(PwpbClientReaderWriter, Closed) { + PwpbClientTestContext ctx; + PwpbClientReaderWriter + call = TestService::TestBidirectionalStreamRpc( + ctx.client(), + ctx.channel().id(), + FailIfOnNextCalled, + FailIfCalled, + FailIfCalled); + ASSERT_EQ(OkStatus(), call.Cancel()); + + ASSERT_FALSE(call.active()); + EXPECT_EQ(call.channel_id(), Channel::kUnassignedChannelId); + + EXPECT_EQ(Status::FailedPrecondition(), call.Write({})); + EXPECT_EQ(Status::FailedPrecondition(), call.Cancel()); + EXPECT_EQ(Status::FailedPrecondition(), call.CloseClientStream()); + + call.set_on_completed([](Status) {}); + call.set_on_next([](const test::TestStreamResponse::Message&) {}); + call.set_on_error([](Status) {}); +} + +TEST(PwpbUnaryReceiver, CallbacksMoveCorrectly) { + PwpbClientTestContext ctx; + + struct { + test::TestResponse::Message payload = {.value = 12345678}; + std::optional status; + } reply; + + PwpbUnaryReceiver call_2; + { + PwpbUnaryReceiver call_1 = TestService::TestUnaryRpc( + ctx.client(), + ctx.channel().id(), + {}, + [&reply](const test::TestResponse::Message& response, Status status) { + reply.payload = response; + reply.status = status; + }); + + call_2 = std::move(call_1); + } + + ctx.server().SendResponse({.value = 9000}, + Status::NotFound()); + EXPECT_EQ(reply.payload.value, 9000); + EXPECT_EQ(reply.status, Status::NotFound()); +} + +TEST(PwpbClientReaderWriter, CallbacksMoveCorrectly) { + PwpbClientTestContext ctx; + + test::TestStreamResponse::Message payload = {.chunk = {}, .number = 13579}; + + PwpbClientReaderWriter + call_2; + { + PwpbClientReaderWriter call_1 = TestService::TestBidirectionalStreamRpc( + ctx.client(), + ctx.channel().id(), + [&payload](const test::TestStreamResponse::Message& response) { + payload = response; + }); + + call_2 = std::move(call_1); + } + + ctx.server().SendServerStream( + {.chunk = {}, .number = 5050}); + EXPECT_EQ(payload.number, 5050u); +} + +} // namespace +} // namespace pw::rpc + +PW_MODIFY_DIAGNOSTICS_POP(); diff --git a/pw_rpc/pwpb/codegen_test.cc b/pw_rpc/pwpb/codegen_test.cc new file mode 100644 index 0000000000..13f09b78fa --- /dev/null +++ b/pw_rpc/pwpb/codegen_test.cc @@ -0,0 +1,386 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include "gtest/gtest.h" +#include "pw_preprocessor/compiler.h" +#include "pw_rpc/internal/hash.h" +#include "pw_rpc/internal/test_utils.h" +#include "pw_rpc/pwpb/test_method_context.h" +#include "pw_rpc_pwpb_private/internal_test_utils.h" +#include "pw_rpc_test_protos/test.rpc.pwpb.h" + +PW_MODIFY_DIAGNOSTICS_PUSH(); +PW_MODIFY_DIAGNOSTIC(ignored, "-Wmissing-field-initializers"); + +namespace pw::rpc { +namespace test { + +class TestService final + : public pw_rpc::pwpb::TestService::Service { + public: + Status TestUnaryRpc(const TestRequest::Message& request, + TestResponse::Message& response) { + response.value = request.integer + 1; + return static_cast(request.status_code); + } + + void TestAnotherUnaryRpc( + const TestRequest::Message& request, + PwpbUnaryResponder& responder) { + TestResponse::Message response{}; + EXPECT_EQ(OkStatus(), + responder.Finish(response, TestUnaryRpc(request, response))); + } + + static void TestServerStreamRpc( + const TestRequest::Message& request, + ServerWriter& writer) { + for (int i = 0; i < request.integer; ++i) { + EXPECT_EQ( + OkStatus(), + writer.Write({.chunk = {}, .number = static_cast(i)})); + } + + EXPECT_EQ(OkStatus(), + writer.Finish(static_cast(request.status_code))); + } + + void TestClientStreamRpc( + ServerReader& + new_reader) { + reader = std::move(new_reader); + } + + void TestBidirectionalStreamRpc( + ServerReaderWriter& + new_reader_writer) { + reader_writer = std::move(new_reader_writer); + } + + ServerReader reader; + ServerReaderWriter + reader_writer; +}; + +} // namespace test + +namespace { + +using internal::ClientContextForTest; + +TEST(PwpbCodegen, CompilesProperly) { + test::TestService service; + EXPECT_EQ(service.id(), internal::Hash("pw.rpc.test.TestService")); + EXPECT_STREQ(service.name(), "TestService"); +} + +TEST(PwpbCodegen, Server_InvokeUnaryRpc) { + PW_PWPB_TEST_METHOD_CONTEXT(test::TestService, TestUnaryRpc) context; + + EXPECT_EQ(OkStatus(), + context.call({.integer = 123, .status_code = OkStatus().code()})); + + EXPECT_EQ(124, context.response().value); + + EXPECT_EQ(Status::InvalidArgument(), + context.call({.integer = 999, + .status_code = Status::InvalidArgument().code()})); + EXPECT_EQ(1000, context.response().value); +} + +TEST(PwpbCodegen, Server_InvokeAsyncUnaryRpc) { + PW_PWPB_TEST_METHOD_CONTEXT(test::TestService, TestAnotherUnaryRpc) context; + + context.call({.integer = 123, .status_code = OkStatus().code()}); + + EXPECT_EQ(OkStatus(), context.status()); + EXPECT_EQ(124, context.response().value); + + context.call( + {.integer = 999, .status_code = Status::InvalidArgument().code()}); + EXPECT_EQ(Status::InvalidArgument(), context.status()); + EXPECT_EQ(1000, context.response().value); +} + +TEST(PwpbCodegen, Server_InvokeServerStreamingRpc) { + PW_PWPB_TEST_METHOD_CONTEXT(test::TestService, TestServerStreamRpc) context; + + context.call({.integer = 0, .status_code = Status::Aborted().code()}); + + EXPECT_EQ(Status::Aborted(), context.status()); + EXPECT_TRUE(context.done()); + EXPECT_EQ(context.total_responses(), 0u); + + context.call({.integer = 4, .status_code = OkStatus().code()}); + + ASSERT_EQ(4u, context.responses().size()); + + for (size_t i = 0; i < context.responses().size(); ++i) { + EXPECT_EQ(context.responses()[i].number, i); + } + + EXPECT_EQ(OkStatus().code(), context.status()); +} + +TEST(PwpbCodegen, Server_InvokeServerStreamingRpc_ManualWriting) { + PW_PWPB_TEST_METHOD_CONTEXT(test::TestService, TestServerStreamRpc, 4) + context; + + ASSERT_EQ(4u, context.max_packets()); + + auto writer = context.writer(); + + EXPECT_EQ(OkStatus(), writer.Write({.chunk = {}, .number = 3})); + EXPECT_EQ(OkStatus(), writer.Write({.chunk = {}, .number = 6})); + EXPECT_EQ(OkStatus(), writer.Write({.chunk = {}, .number = 9})); + + EXPECT_FALSE(context.done()); + + EXPECT_EQ(OkStatus(), writer.Finish(Status::Cancelled())); + ASSERT_TRUE(context.done()); + EXPECT_EQ(Status::Cancelled(), context.status()); + + ASSERT_EQ(3u, context.responses().size()); + + EXPECT_EQ(context.responses()[0].number, 3u); + EXPECT_EQ(context.responses()[1].number, 6u); + EXPECT_EQ(context.responses()[2].number, 9u); +} + +TEST(PwpbCodegen, Server_InvokeClientStreamingRpc) { + PW_PWPB_TEST_METHOD_CONTEXT(test::TestService, TestClientStreamRpc) context; + + context.call(); + + test::TestRequest::Message request = {}; + context.service().reader.set_on_next( + [&request](const test::TestRequest::Message& req) { request = req; }); + + context.SendClientStream({.integer = -99, .status_code = 10}); + EXPECT_EQ(request.integer, -99); + EXPECT_EQ(request.status_code, 10u); + + ASSERT_EQ(OkStatus(), + context.service().reader.Finish({.chunk = {}, .number = 3}, + Status::Unimplemented())); + EXPECT_EQ(Status::Unimplemented(), context.status()); + EXPECT_EQ(context.response().number, 3u); +} + +TEST(PwpbCodegen, Server_InvokeBidirectionalStreamingRpc) { + PW_PWPB_TEST_METHOD_CONTEXT(test::TestService, TestBidirectionalStreamRpc) + context; + + context.call(); + + test::TestRequest::Message request = {}; + context.service().reader_writer.set_on_next( + [&request](const test::TestRequest::Message& req) { request = req; }); + + context.SendClientStream({.integer = -99, .status_code = 10}); + EXPECT_EQ(request.integer, -99); + EXPECT_EQ(request.status_code, 10u); + + ASSERT_EQ(OkStatus(), + context.service().reader_writer.Write({.chunk = {}, .number = 2})); + EXPECT_EQ(context.responses()[0].number, 2u); + + ASSERT_EQ(OkStatus(), + context.service().reader_writer.Finish(Status::NotFound())); + EXPECT_EQ(Status::NotFound(), context.status()); +} + +TEST(PwpbCodegen, ClientCall_DefaultConstructor) { + PwpbUnaryReceiver unary_call; + PwpbClientReader server_streaming_call; +} + +using TestServiceClient = test::pw_rpc::pwpb::TestService::Client; + +TEST(PwpbCodegen, Client_InvokesUnaryRpcWithCallback) { + constexpr uint32_t kServiceId = internal::Hash("pw.rpc.test.TestService"); + constexpr uint32_t kMethodId = internal::Hash("TestUnaryRpc"); + + ClientContextForTest<128, 99, kServiceId, kMethodId> context; + + TestServiceClient test_client(context.client(), context.channel().id()); + + struct { + Status last_status = Status::Unknown(); + int response_value = -1; + } result; + + auto call = test_client.TestUnaryRpc( + {.integer = 123, .status_code = 0}, + [&result](const test::TestResponse::Message& response, Status status) { + result.last_status = status; + result.response_value = response.value; + }); + + EXPECT_TRUE(call.active()); + + EXPECT_EQ(context.output().total_packets(), 1u); + auto packet = + static_cast(context.output()) + .last_packet(); + EXPECT_EQ(packet.channel_id(), context.channel().id()); + EXPECT_EQ(packet.service_id(), kServiceId); + EXPECT_EQ(packet.method_id(), kMethodId); + PW_DECODE_PB(test::TestRequest, sent_proto, packet.payload()); + EXPECT_EQ(sent_proto.integer, 123); + + PW_ENCODE_PB(test::TestResponse, response, .value = 42); + EXPECT_EQ(OkStatus(), context.SendResponse(OkStatus(), response)); + EXPECT_EQ(result.last_status, OkStatus()); + EXPECT_EQ(result.response_value, 42); + + EXPECT_FALSE(call.active()); +} + +TEST(PwpbCodegen, Client_InvokesServerStreamingRpcWithCallback) { + constexpr uint32_t kServiceId = internal::Hash("pw.rpc.test.TestService"); + constexpr uint32_t kMethodId = internal::Hash("TestServerStreamRpc"); + + ClientContextForTest<128, 99, kServiceId, kMethodId> context; + + TestServiceClient test_client(context.client(), context.channel().id()); + + struct { + bool active = true; + Status stream_status = Status::Unknown(); + int response_value = -1; + } result; + + auto call = test_client.TestServerStreamRpc( + {.integer = 123, .status_code = 0}, + [&result](const test::TestStreamResponse::Message& response) { + result.active = true; + result.response_value = response.number; + }, + [&result](Status status) { + result.active = false; + result.stream_status = status; + }); + + EXPECT_TRUE(call.active()); + + EXPECT_EQ(context.output().total_packets(), 1u); + auto packet = + static_cast(context.output()) + .last_packet(); + EXPECT_EQ(packet.channel_id(), context.channel().id()); + EXPECT_EQ(packet.service_id(), kServiceId); + EXPECT_EQ(packet.method_id(), kMethodId); + PW_DECODE_PB(test::TestRequest, sent_proto, packet.payload()); + EXPECT_EQ(sent_proto.integer, 123); + + PW_ENCODE_PB(test::TestStreamResponse, response, .chunk = {}, .number = 11u); + EXPECT_EQ(OkStatus(), context.SendServerStream(response)); + EXPECT_TRUE(result.active); + EXPECT_EQ(result.response_value, 11); + + EXPECT_EQ(OkStatus(), context.SendResponse(Status::NotFound())); + EXPECT_FALSE(result.active); + EXPECT_EQ(result.stream_status, Status::NotFound()); +} + +TEST(PwpbCodegen, Client_StaticMethod_InvokesUnaryRpcWithCallback) { + constexpr uint32_t kServiceId = internal::Hash("pw.rpc.test.TestService"); + constexpr uint32_t kMethodId = internal::Hash("TestUnaryRpc"); + + ClientContextForTest<128, 99, kServiceId, kMethodId> context; + + struct { + Status last_status = Status::Unknown(); + int response_value = -1; + } result; + + auto call = test::pw_rpc::pwpb::TestService::TestUnaryRpc( + context.client(), + context.channel().id(), + {.integer = 123, .status_code = 0}, + [&result](const test::TestResponse::Message& response, Status status) { + result.last_status = status; + result.response_value = response.value; + }); + + EXPECT_TRUE(call.active()); + + EXPECT_EQ(context.output().total_packets(), 1u); + auto packet = + static_cast(context.output()) + .last_packet(); + EXPECT_EQ(packet.channel_id(), context.channel().id()); + EXPECT_EQ(packet.service_id(), kServiceId); + EXPECT_EQ(packet.method_id(), kMethodId); + PW_DECODE_PB(test::TestRequest, sent_proto, packet.payload()); + EXPECT_EQ(sent_proto.integer, 123); + + PW_ENCODE_PB(test::TestResponse, response, .value = 42); + EXPECT_EQ(OkStatus(), context.SendResponse(OkStatus(), response)); + EXPECT_EQ(result.last_status, OkStatus()); + EXPECT_EQ(result.response_value, 42); +} + +TEST(PwpbCodegen, Client_StaticMethod_InvokesServerStreamingRpcWithCallback) { + constexpr uint32_t kServiceId = internal::Hash("pw.rpc.test.TestService"); + constexpr uint32_t kMethodId = internal::Hash("TestServerStreamRpc"); + + ClientContextForTest<128, 99, kServiceId, kMethodId> context; + + struct { + bool active = true; + Status stream_status = Status::Unknown(); + int response_value = -1; + } result; + + auto call = test::pw_rpc::pwpb::TestService::TestServerStreamRpc( + context.client(), + context.channel().id(), + {.integer = 123, .status_code = 0}, + [&result](const test::TestStreamResponse::Message& response) { + result.active = true; + result.response_value = response.number; + }, + [&result](Status status) { + result.active = false; + result.stream_status = status; + }); + + EXPECT_TRUE(call.active()); + + EXPECT_EQ(context.output().total_packets(), 1u); + auto packet = + static_cast(context.output()) + .last_packet(); + EXPECT_EQ(packet.channel_id(), context.channel().id()); + EXPECT_EQ(packet.service_id(), kServiceId); + EXPECT_EQ(packet.method_id(), kMethodId); + PW_DECODE_PB(test::TestRequest, sent_proto, packet.payload()); + EXPECT_EQ(sent_proto.integer, 123); + + PW_ENCODE_PB(test::TestStreamResponse, response, .chunk = {}, .number = 11u); + EXPECT_EQ(OkStatus(), context.SendServerStream(response)); + EXPECT_TRUE(result.active); + EXPECT_EQ(result.response_value, 11); + + EXPECT_EQ(OkStatus(), context.SendResponse(Status::NotFound())); + EXPECT_FALSE(result.active); + EXPECT_EQ(result.stream_status, Status::NotFound()); +} + +} // namespace +} // namespace pw::rpc + +PW_MODIFY_DIAGNOSTICS_POP(); diff --git a/pw_rpc/pwpb/docs.rst b/pw_rpc/pwpb/docs.rst new file mode 100644 index 0000000000..9096d60658 --- /dev/null +++ b/pw_rpc/pwpb/docs.rst @@ -0,0 +1,259 @@ +.. _module-pw_rpc_pw_protobuf: + +----------- +pw_protobuf +----------- +``pw_rpc`` can generate services which encode/decode RPC requests and responses +as ``pw_protobuf`` message structs + +Usage +===== +Define a ``pw_proto_library`` containing the .proto file defining your service +(and optionally other related protos), then depend on the ``pwpb_rpc`` +version of that library in the code implementing the service. + +.. code:: + + # chat/BUILD.gn + + import("$dir_pw_build/target_types.gni") + import("$dir_pw_protobuf_compiler/proto.gni") + + pw_proto_library("chat_protos") { + sources = [ "chat_protos/chat_service.proto" ] + } + + # Library that implements the Chat service. + pw_source_set("chat_service") { + sources = [ + "chat_service.cc", + "chat_service.h", + ] + public_deps = [ ":chat_protos.pwpb_rpc" ] + } + +A C++ header file is generated for each input .proto file, with the ``.proto`` +extension replaced by ``.rpc.pwpb.h``. For example, given the input file +``chat_protos/chat_service.proto``, the generated header file will be placed +at the include path ``"chat_protos/chat_service.rpc.pwpb.h"``. + +Generated code API +================== +All examples in this document use the following RPC service definition. + +.. code:: protobuf + + // chat/chat_protos/chat_service.proto + + syntax = "proto3"; + + service Chat { + // Returns information about a chatroom. + rpc GetRoomInformation(RoomInfoRequest) returns (RoomInfoResponse) {} + + // Lists all of the users in a chatroom. The response is streamed as there + // may be a large amount of users. + rpc ListUsersInRoom(ListUsersRequest) returns (stream ListUsersResponse) {} + + // Uploads a file, in chunks, to a chatroom. + rpc UploadFile(stream UploadFileRequest) returns (UploadFileResponse) {} + + // Sends messages to a chatroom while receiving messages from other users. + rpc Chat(stream ChatMessage) returns (stream ChatMessage) {} + } + +Server-side +----------- +A C++ class is generated for each service in the .proto file. The class is +located within a special ``pw_rpc::pwpb`` sub-namespace of the file's package. + +The generated class is a base class which must be derived to implement the +service's methods. The base class is templated on the derived class. + +.. code:: c++ + + #include "chat_protos/chat_service.rpc.pwpb.h" + + class ChatService final : public pw_rpc::pwpb::Chat::Service { + public: + // Implementations of the service's RPC methods; see below. + }; + +Unary RPC +^^^^^^^^^ +A unary RPC is implemented as a function which takes in the RPC's request struct +and populates a response struct to send back, with a status indicating whether +the request succeeded. + +.. code:: c++ + + pw::Status GetRoomInformation(const RoomInfoRequest::Message& request, + RoomInfoResponse::Message& response); + +Server streaming RPC +^^^^^^^^^^^^^^^^^^^^ +A server streaming RPC receives the client's request message alongside a +``ServerWriter``, used to stream back responses. + +.. code:: c++ + + void ListUsersInRoom(const ListUsersRequest::Message& request, + pw::rpc::ServerWriter& writer); + +The ``ServerWriter`` object is movable, and remains active until it is manually +closed or goes out of scope. The writer has a simple API to return responses: + +.. cpp:function:: Status PwpbServerWriter::Write(const T::Message& response) + + Writes a single response message to the stream. The returned status indicates + whether the write was successful. + +.. cpp:function:: void PwpbServerWriter::Finish(Status status = OkStatus()) + + Closes the stream and sends back the RPC's overall status to the client. + +Once a ``ServerWriter`` has been closed, all future ``Write`` calls will fail. + +.. attention:: + + Make sure to use ``std::move`` when passing the ``ServerWriter`` around to + avoid accidentally closing it and ending the RPC. + +Client streaming RPC +^^^^^^^^^^^^^^^^^^^^ +.. attention:: + + ``pw_rpc`` does not yet support client streaming RPCs. + +Bidirectional streaming RPC +^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. attention:: + + ``pw_rpc`` does not yet support bidirectional streaming RPCs. + +Client-side +----------- +A corresponding client class is generated for every service defined in the proto +file. To allow multiple types of clients to exist, it is placed under the +``pw_rpc::pwpb`` namespace. The ``Client`` class is nested under +``pw_rpc::pwpb::ServiceName``. For example, the ``Chat`` service would create +``pw_rpc::pwpb::Chat::Client``. + +Service clients are instantiated with a reference to the RPC client through +which they will send requests, and the channel ID they will use. + +.. code-block:: c++ + + // Nested under pw_rpc::pwpb::ServiceName. + class Client { + public: + Client(::pw::rpc::Client& client, uint32_t channel_id); + + GetRoomInformationCall GetRoomInformation( + const RoomInfoRequest::Message& request, + ::pw::Function on_response, + ::pw::Function on_rpc_error = nullptr); + + // ...and more (see below). + }; + +RPCs can also be invoked individually as free functions: + +.. code-block:: c++ + + GetRoomInformationCall call = pw_rpc::pwpb::Chat::GetRoomInformation( + client, channel_id, request, on_response, on_rpc_error); + +The client class has member functions for each method defined within the +service's protobuf descriptor. The arguments to these methods vary depending on +the type of RPC. Each method returns a ``PwpbClientCall`` object which stores +the context of the ongoing RPC call. For more information on ``ClientCall`` +objects, refer to the :ref:`core RPC docs `. The +type of the returned object is complex, so it is aliased using the method +name. + +.. admonition:: Callback invocation + + RPC callbacks are invoked synchronously from ``Client::ProcessPacket``. + +Method APIs +^^^^^^^^^^^ +The arguments provided when invoking a method depend on its type. + +Unary RPC +~~~~~~~~~ +A unary RPC call takes the request struct and a callback to invoke when a +response is received. The callback receives the RPC's status and response +struct. + +An optional second callback can be provided to handle internal errors. + +.. code-block:: c++ + + GetRoomInformationCall GetRoomInformation( + const RoomInfoRequest::Message& request, + ::pw::Function on_response, + ::pw::Function on_rpc_error = nullptr); + +Server streaming RPC +~~~~~~~~~~~~~~~~~~~~ +A server streaming RPC call takes the initial request struct and two callbacks. +The first is invoked on every stream response received, and the second is +invoked once the stream is complete with its overall status. + +An optional third callback can be provided to handle internal errors. + +.. code-block:: c++ + + ListUsersInRoomCall ListUsersInRoom( + const ListUsersRequest::Message& request, + ::pw::Function on_response, + ::pw::Function on_stream_end, + ::pw::Function on_rpc_error = nullptr); + +Example usage +^^^^^^^^^^^^^ +The following example demonstrates how to call an RPC method using a pw_protobuf +service client and receive the response. + +.. code-block:: c++ + + #include "chat_protos/chat_service.rpc.pwpb.h" + + namespace { + + using ChatClient = pw_rpc::pwpb::Chat::Client; + + MyChannelOutput output; + pw::rpc::Channel channels[] = {pw::rpc::Channel::Create<1>(&output)}; + pw::rpc::Client client(channels); + + // Callback function for GetRoomInformation. + void LogRoomInformation(const RoomInfoResponse::Message& response, + Status status); + + } // namespace + + void InvokeSomeRpcs() { + // Instantiate a service client to call Chat service methods on channel 1. + ChatClient chat_client(client, 1); + + // The RPC will remain active as long as `call` is alive. + auto call = chat_client.GetRoomInformation( + {.room = "pigweed"}, LogRoomInformation); + if (!call.active()) { + // The invocation may fail. This could occur due to an invalid channel ID, + // for example. The failure status is forwarded to the to call's + // on_rpc_error callback. + return; + } + + // For simplicity, block until the call completes. An actual implementation + // would likely std::move the call somewhere to keep it active while doing + // other work. + while (call.active()) { + Wait(); + } + + // Do other stuff now that we have the room information. + } diff --git a/pw_rpc/pwpb/echo_service_test.cc b/pw_rpc/pwpb/echo_service_test.cc new file mode 100644 index 0000000000..80651c65c3 --- /dev/null +++ b/pw_rpc/pwpb/echo_service_test.cc @@ -0,0 +1,41 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include + +#include "gtest/gtest.h" +#include "pw_rpc/echo_service_pwpb.h" +#include "pw_rpc/pwpb/test_method_context.h" + +namespace pw::rpc { +namespace { + +TEST(EchoService, Echo_EchoesRequestMessage) { + PW_PWPB_TEST_METHOD_CONTEXT(EchoService, Echo) context; + ASSERT_EQ(context.call({"Hello, world"}), OkStatus()); + EXPECT_EQ(std::string_view(context.response().msg.data(), + context.response().msg.size()), + "Hello, world"); +} + +TEST(EchoService, Echo_EmptyRequest) { + PW_PWPB_TEST_METHOD_CONTEXT(EchoService, Echo) context; + ASSERT_EQ(context.call({}), OkStatus()); + EXPECT_EQ(std::string_view(context.response().msg.data(), + context.response().msg.size()), + ""); +} + +} // namespace +} // namespace pw::rpc diff --git a/pw_rpc/pwpb/fake_channel_output_test.cc b/pw_rpc/pwpb/fake_channel_output_test.cc new file mode 100644 index 0000000000..734376bd2a --- /dev/null +++ b/pw_rpc/pwpb/fake_channel_output_test.cc @@ -0,0 +1,97 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include "pw_rpc/pwpb/fake_channel_output.h" + +#include +#include +#include + +#include "gtest/gtest.h" +#include "pw_rpc/internal/channel.h" +#include "pw_rpc/internal/packet.h" +#include "pw_rpc_test_protos/test.rpc.pwpb.h" + +PW_MODIFY_DIAGNOSTICS_PUSH(); +PW_MODIFY_DIAGNOSTIC(ignored, "-Wmissing-field-initializers"); + +namespace pw::rpc::internal::test { +namespace { + +using rpc::test::pw_rpc::pwpb::TestService; +using Info = internal::MethodInfo; + +TEST(PwpbFakeChannelOutput, Requests) { + PwpbFakeChannelOutput<1> output; + + std::byte payload_buffer[32] = {}; + constexpr Info::Request request{.integer = -100, .status_code = 5}; + const StatusWithSize payload = + Info::serde().EncodeRequest(request, payload_buffer); + ASSERT_TRUE(payload.ok()); + + std::array buffer; + + auto packet = Packet(PacketType::REQUEST, + 1, + Info::kServiceId, + Info::kMethodId, + 999, + std::span(payload_buffer, payload.size())) + .Encode(buffer); + ASSERT_TRUE(packet.ok()); + + ASSERT_EQ(OkStatus(), output.Send(std::span(buffer).first(packet->size()))); + + ASSERT_TRUE(output.responses().empty()); + ASSERT_EQ(output.requests().size(), 1u); + + Info::Request sent = output.requests().front(); + EXPECT_EQ(sent.integer, -100); + EXPECT_EQ(sent.status_code, 5u); +} + +TEST(PwpbFakeChannelOutput, Responses) { + PwpbFakeChannelOutput<1> output; + + std::byte payload_buffer[32] = {}; + const Info::Response response{.value = -9876}; + const StatusWithSize payload = + Info::serde().EncodeResponse(response, payload_buffer); + ASSERT_TRUE(payload.ok()); + + std::array buffer; + + auto packet = Packet(PacketType::RESPONSE, + 1, + Info::kServiceId, + Info::kMethodId, + 999, + std::span(payload_buffer, payload.size())) + .Encode(buffer); + ASSERT_TRUE(packet.ok()); + + ASSERT_EQ(OkStatus(), output.Send(std::span(buffer).first(packet->size()))); + + ASSERT_EQ(output.responses().size(), 1u); + ASSERT_TRUE(output.requests().empty()); + + Info::Response sent = output.responses().front(); + EXPECT_EQ(sent.value, -9876); +} + +} // namespace +} // namespace pw::rpc::internal::test + +PW_MODIFY_DIAGNOSTICS_POP(); diff --git a/pw_rpc/pwpb/method_info_test.cc b/pw_rpc/pwpb/method_info_test.cc new file mode 100644 index 0000000000..4e96272062 --- /dev/null +++ b/pw_rpc/pwpb/method_info_test.cc @@ -0,0 +1,53 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include "pw_rpc/internal/method_info.h" + +#include "gtest/gtest.h" +#include "pw_rpc/internal/method_info_tester.h" +#include "pw_rpc_test_protos/test.rpc.pwpb.h" +#include "pw_status/status.h" + +namespace pw::rpc { +namespace { + +class TestService final + : public test::pw_rpc::pwpb::TestService::Service { + public: + Status TestUnaryRpc(const test::TestRequest::Message&, + test::TestResponse::Message&) { + return OkStatus(); + } + + void TestAnotherUnaryRpc(const test::TestRequest::Message&, + PwpbUnaryResponder&) {} + + static void TestServerStreamRpc( + const test::TestRequest::Message&, + ServerWriter&) {} + + void TestClientStreamRpc(ServerReader&) {} + + void TestBidirectionalStreamRpc( + ServerReaderWriter&) {} +}; + +static_assert( + internal::MethodInfoTests() + .Pass()); + +} // namespace +} // namespace pw::rpc diff --git a/pw_rpc/pwpb/method_lookup_test.cc b/pw_rpc/pwpb/method_lookup_test.cc new file mode 100644 index 0000000000..1b3c88412c --- /dev/null +++ b/pw_rpc/pwpb/method_lookup_test.cc @@ -0,0 +1,159 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include "gtest/gtest.h" +#include "pw_rpc/pwpb/test_method_context.h" +#include "pw_rpc/raw/test_method_context.h" +#include "pw_rpc_test_protos/test.rpc.pwpb.h" + +namespace pw::rpc { +namespace { + +class MixedService1 + : public test::pw_rpc::pwpb::TestService::Service { + public: + void TestUnaryRpc(ConstByteSpan, RawUnaryResponder& responder) { + std::byte response[5] = {}; + ASSERT_EQ(OkStatus(), responder.Finish(response, OkStatus())); + } + + void TestAnotherUnaryRpc(const test::TestRequest::Message&, + PwpbUnaryResponder&) { + called_async_unary_method = true; + } + + void TestServerStreamRpc(const test::TestRequest::Message&, + ServerWriter&) { + called_server_streaming_method = true; + } + + void TestClientStreamRpc(RawServerReader&) { + called_client_streaming_method = true; + } + + void TestBidirectionalStreamRpc( + ServerReaderWriter&) { + called_bidirectional_streaming_method = true; + } + + bool called_async_unary_method = false; + bool called_server_streaming_method = false; + bool called_client_streaming_method = false; + bool called_bidirectional_streaming_method = false; +}; + +class MixedService2 + : public test::pw_rpc::pwpb::TestService::Service { + public: + Status TestUnaryRpc(const test::TestRequest::Message&, + test::TestResponse::Message&) { + return Status::Unauthenticated(); + } + + void TestAnotherUnaryRpc(ConstByteSpan, RawUnaryResponder&) { + called_async_unary_method = true; + } + + void TestServerStreamRpc(ConstByteSpan, RawServerWriter&) { + called_server_streaming_method = true; + } + + void TestClientStreamRpc(ServerReader&) { + called_client_streaming_method = true; + } + + void TestBidirectionalStreamRpc(RawServerReaderWriter&) { + called_bidirectional_streaming_method = true; + } + + bool called_async_unary_method = false; + bool called_server_streaming_method = false; + bool called_client_streaming_method = false; + bool called_bidirectional_streaming_method = false; +}; + +TEST(MixedService1, CallRawMethod_SyncUnary) { + PW_RAW_TEST_METHOD_CONTEXT(MixedService1, TestUnaryRpc) context; + context.call({}); + EXPECT_EQ(OkStatus(), context.status()); + EXPECT_EQ(5u, context.response().size()); +} + +TEST(MixedService1, CallPwpbMethod_AsyncUnary) { + PW_PWPB_TEST_METHOD_CONTEXT(MixedService1, TestAnotherUnaryRpc) context; + ASSERT_FALSE(context.service().called_async_unary_method); + context.call({}); + EXPECT_TRUE(context.service().called_async_unary_method); +} + +TEST(MixedService1, CallPwpbMethod_ServerStreaming) { + PW_PWPB_TEST_METHOD_CONTEXT(MixedService1, TestServerStreamRpc) context; + ASSERT_FALSE(context.service().called_server_streaming_method); + context.call({}); + EXPECT_TRUE(context.service().called_server_streaming_method); +} + +TEST(MixedService1, CallRawMethod_ClientStreaming) { + PW_RAW_TEST_METHOD_CONTEXT(MixedService1, TestClientStreamRpc) context; + ASSERT_FALSE(context.service().called_client_streaming_method); + context.call(); + EXPECT_TRUE(context.service().called_client_streaming_method); +} + +TEST(MixedService1, CallPwpbMethod_BidirectionalStreaming) { + PW_PWPB_TEST_METHOD_CONTEXT(MixedService1, TestBidirectionalStreamRpc) + context; + ASSERT_FALSE(context.service().called_bidirectional_streaming_method); + context.call(); + EXPECT_TRUE(context.service().called_bidirectional_streaming_method); +} + +TEST(MixedService2, CallPwpbMethod_SyncUnary) { + PW_PWPB_TEST_METHOD_CONTEXT(MixedService2, TestUnaryRpc) context; + Status status = context.call({}); + EXPECT_EQ(Status::Unauthenticated(), status); +} + +TEST(MixedService2, CallRawMethod_AsyncUnary) { + PW_RAW_TEST_METHOD_CONTEXT(MixedService2, TestAnotherUnaryRpc) context; + ASSERT_FALSE(context.service().called_async_unary_method); + context.call({}); + EXPECT_TRUE(context.service().called_async_unary_method); +} + +TEST(MixedService2, CallRawMethod_ServerStreaming) { + PW_RAW_TEST_METHOD_CONTEXT(MixedService2, TestServerStreamRpc) context; + ASSERT_FALSE(context.service().called_server_streaming_method); + context.call({}); + EXPECT_TRUE(context.service().called_server_streaming_method); +} + +TEST(MixedService2, CallPwpbMethod_ClientStreaming) { + PW_PWPB_TEST_METHOD_CONTEXT(MixedService2, TestClientStreamRpc) context; + ASSERT_FALSE(context.service().called_client_streaming_method); + context.call(); + EXPECT_TRUE(context.service().called_client_streaming_method); +} + +TEST(MixedService2, CallRawMethod_BidirectionalStreaming) { + PW_RAW_TEST_METHOD_CONTEXT(MixedService2, TestBidirectionalStreamRpc) context; + ASSERT_FALSE(context.service().called_bidirectional_streaming_method); + context.call(); + EXPECT_TRUE(context.service().called_bidirectional_streaming_method); +} + +} // namespace +} // namespace pw::rpc diff --git a/pw_rpc/pwpb/method_test.cc b/pw_rpc/pwpb/method_test.cc new file mode 100644 index 0000000000..4944a5c40d --- /dev/null +++ b/pw_rpc/pwpb/method_test.cc @@ -0,0 +1,425 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include "pw_rpc/pwpb/internal/method.h" + +#include + +#include "gtest/gtest.h" +#include "pw_rpc/internal/lock.h" +#include "pw_rpc/internal/method_impl_tester.h" +#include "pw_rpc/internal/test_utils.h" +#include "pw_rpc/pwpb/internal/method_union.h" +#include "pw_rpc/service.h" +#include "pw_rpc_pwpb_private/internal_test_utils.h" +#include "pw_rpc_test_protos/test.pwpb.h" + +PW_MODIFY_DIAGNOSTICS_PUSH(); +PW_MODIFY_DIAGNOSTIC(ignored, "-Wmissing-field-initializers"); + +namespace pw::rpc::internal { +namespace { + +using std::byte; + +struct FakePb {}; + +// Create a fake service for use with the MethodImplTester. +class TestPwpbService final : public Service { + public: + // Unary signatures + + Status Unary(const FakePb&, FakePb&) { return Status(); } + + static Status StaticUnary(const FakePb&, FakePb&) { return Status(); } + + void AsyncUnary(const FakePb&, PwpbUnaryResponder&) {} + + static void StaticAsyncUnary(const FakePb&, PwpbUnaryResponder&) {} + + Status UnaryWrongArg(FakePb&, FakePb&) { return Status(); } + + static void StaticUnaryVoidReturn(const FakePb&, FakePb&) {} + + // Server streaming signatures + + void ServerStreaming(const FakePb&, PwpbServerWriter&) {} + + static void StaticServerStreaming(const FakePb&, PwpbServerWriter&) {} + + int ServerStreamingBadReturn(const FakePb&, PwpbServerWriter&) { + return 5; + } + + static void StaticServerStreamingMissingArg(PwpbServerWriter&) {} + + // Client streaming signatures + + void ClientStreaming(PwpbServerReader&) {} + + static void StaticClientStreaming(PwpbServerReader&) {} + + int ClientStreamingBadReturn(PwpbServerReader&) { return 0; } + + static void StaticClientStreamingMissingArg() {} + + // Bidirectional streaming signatures + + void BidirectionalStreaming(PwpbServerReaderWriter&) {} + + static void StaticBidirectionalStreaming( + PwpbServerReaderWriter&) {} + + int BidirectionalStreamingBadReturn(PwpbServerReaderWriter&) { + return 0; + } + + static void StaticBidirectionalStreamingMissingArg() {} +}; + +struct WrongPb; + +// Test matches() rejects incorrect request/response types. +// clang-format off +static_assert(!PwpbMethod::template matches<&TestPwpbService::Unary, WrongPb, FakePb>()); +static_assert(!PwpbMethod::template matches<&TestPwpbService::Unary, FakePb, WrongPb>()); +static_assert(!PwpbMethod::template matches<&TestPwpbService::Unary, WrongPb, WrongPb>()); +static_assert(!PwpbMethod::template matches<&TestPwpbService::StaticUnary, FakePb, WrongPb>()); + +static_assert(!PwpbMethod::template matches<&TestPwpbService::ServerStreaming, WrongPb, FakePb>()); +static_assert(!PwpbMethod::template matches<&TestPwpbService::StaticServerStreaming, FakePb, WrongPb>()); + +static_assert(!PwpbMethod::template matches<&TestPwpbService::ClientStreaming, WrongPb, FakePb>()); +static_assert(!PwpbMethod::template matches<&TestPwpbService::StaticClientStreaming, FakePb, WrongPb>()); + +static_assert(!PwpbMethod::template matches<&TestPwpbService::BidirectionalStreaming, WrongPb, FakePb>()); +static_assert(!PwpbMethod::template matches<&TestPwpbService::StaticBidirectionalStreaming, FakePb, WrongPb>()); +// clang-format on + +static_assert(MethodImplTests().Pass( + MatchesTypes(), + std::tuple(kPwpbMethodSerde))); + +template +class FakeServiceBase : public Service { + public: + FakeServiceBase(uint32_t id) : Service(id, kMethods) {} + + static constexpr std::array kMethods = { + PwpbMethod::SynchronousUnary<&Impl::DoNothing>( + 10u, + kPwpbMethodSerde<&pw::rpc::test::Empty::kMessageFields, + &pw::rpc::test::Empty::kMessageFields>), + PwpbMethod::AsynchronousUnary<&Impl::AddFive>( + 11u, + kPwpbMethodSerde<&pw::rpc::test::TestRequest::kMessageFields, + &pw::rpc::test::TestResponse::kMessageFields>), + PwpbMethod::ServerStreaming<&Impl::StartStream>( + 12u, + kPwpbMethodSerde<&pw::rpc::test::TestRequest::kMessageFields, + &pw::rpc::test::TestResponse::kMessageFields>), + PwpbMethod::ClientStreaming<&Impl::ClientStream>( + 13u, + kPwpbMethodSerde<&pw::rpc::test::TestRequest::kMessageFields, + &pw::rpc::test::TestResponse::kMessageFields>), + PwpbMethod::BidirectionalStreaming<&Impl::BidirectionalStream>( + 14u, + kPwpbMethodSerde<&pw::rpc::test::TestRequest::kMessageFields, + &pw::rpc::test::TestResponse::kMessageFields>)}; +}; + +class FakeService : public FakeServiceBase { + public: + FakeService(uint32_t id) : FakeServiceBase(id) {} + + Status DoNothing(const pw::rpc::test::Empty::Message&, + pw::rpc::test::Empty::Message&) { + return Status::Unknown(); + } + + void AddFive( + const pw::rpc::test::TestRequest::Message& request, + PwpbUnaryResponder& responder) { + last_request = request; + + if (fail_to_encode_async_unary_response) { + pw::rpc::test::TestResponse::Message response = {}; + response.repeated_field.SetEncoder( + [](const pw::rpc::test::TestResponse::StreamEncoder&) { + return Status::Internal(); + }); + ASSERT_EQ(OkStatus(), responder.Finish(response, Status::NotFound())); + } else { + ASSERT_EQ( + OkStatus(), + responder.Finish({.value = static_cast(request.integer + 5)}, + Status::Unauthenticated())); + } + } + + void StartStream( + const pw::rpc::test::TestRequest::Message& request, + PwpbServerWriter& writer) { + last_request = request; + last_writer = std::move(writer); + } + + void ClientStream( + PwpbServerReader& reader) { + last_reader = std::move(reader); + } + + void BidirectionalStream( + PwpbServerReaderWriter& + reader_writer) { + last_reader_writer = std::move(reader_writer); + } + + bool fail_to_encode_async_unary_response = false; + + pw::rpc::test::TestRequest::Message last_request; + PwpbServerWriter last_writer; + PwpbServerReader + last_reader; + PwpbServerReaderWriter + last_reader_writer; +}; + +constexpr const PwpbMethod& kSyncUnary = + std::get<0>(FakeServiceBase::kMethods).pwpb_method(); +constexpr const PwpbMethod& kAsyncUnary = + std::get<1>(FakeServiceBase::kMethods).pwpb_method(); +constexpr const PwpbMethod& kServerStream = + std::get<2>(FakeServiceBase::kMethods).pwpb_method(); +constexpr const PwpbMethod& kClientStream = + std::get<3>(FakeServiceBase::kMethods).pwpb_method(); +constexpr const PwpbMethod& kBidirectionalStream = + std::get<4>(FakeServiceBase::kMethods).pwpb_method(); + +TEST(PwpbMethod, AsyncUnaryRpc_SendsResponse) { + PW_ENCODE_PB( + pw::rpc::test::TestRequest, request, .integer = 123, .status_code = 0); + + ServerContextForTest context(kAsyncUnary); + rpc_lock().lock(); + kAsyncUnary.Invoke(context.get(), context.request(request)); + + const Packet& response = context.output().last_packet(); + EXPECT_EQ(response.status(), Status::Unauthenticated()); + + // Field 1 (encoded as 1 << 3) with 128 as the value. + constexpr std::byte expected[]{ + std::byte{0x08}, std::byte{0x80}, std::byte{0x01}}; + + EXPECT_EQ(sizeof(expected), response.payload().size()); + EXPECT_EQ(0, + std::memcmp(expected, response.payload().data(), sizeof(expected))); + + EXPECT_EQ(123, context.service().last_request.integer); +} + +TEST(PwpbMethod, SyncUnaryRpc_InvalidPayload_SendsError) { + std::array bad_payload{byte{0xFF}, byte{0xAA}, byte{0xDD}}; + + ServerContextForTest context(kSyncUnary); + rpc_lock().lock(); + kSyncUnary.Invoke(context.get(), context.request(bad_payload)); + + const Packet& packet = context.output().last_packet(); + EXPECT_EQ(PacketType::SERVER_ERROR, packet.type()); + EXPECT_EQ(Status::DataLoss(), packet.status()); + EXPECT_EQ(context.service_id(), packet.service_id()); + EXPECT_EQ(kSyncUnary.id(), packet.method_id()); +} + +TEST(PwpbMethod, AsyncUnaryRpc_ResponseEncodingFails_SendsInternalError) { + constexpr int64_t value = 0x7FFFFFFF'FFFFFF00ll; + PW_ENCODE_PB( + pw::rpc::test::TestRequest, request, .integer = value, .status_code = 0); + + ServerContextForTest context(kAsyncUnary); + context.service().fail_to_encode_async_unary_response = true; + + rpc_lock().lock(); + kAsyncUnary.Invoke(context.get(), context.request(request)); + + const Packet& packet = context.output().last_packet(); + EXPECT_EQ(PacketType::SERVER_ERROR, packet.type()); + EXPECT_EQ(Status::Internal(), packet.status()); + EXPECT_EQ(context.service_id(), packet.service_id()); + EXPECT_EQ(kAsyncUnary.id(), packet.method_id()); + + EXPECT_EQ(value, context.service().last_request.integer); +} + +TEST(PwpbMethod, ServerStreamingRpc_SendsNothingWhenInitiallyCalled) { + PW_ENCODE_PB( + pw::rpc::test::TestRequest, request, .integer = 555, .status_code = 0); + + ServerContextForTest context(kServerStream); + + rpc_lock().lock(); + kServerStream.Invoke(context.get(), context.request(request)); + + EXPECT_EQ(0u, context.output().total_packets()); + EXPECT_EQ(555, context.service().last_request.integer); +} + +TEST(PwpbMethod, ServerWriter_SendsResponse) { + ServerContextForTest context(kServerStream); + + rpc_lock().lock(); + kServerStream.Invoke(context.get(), context.request({})); + + EXPECT_EQ(OkStatus(), context.service().last_writer.Write({.value = 100})); + + PW_ENCODE_PB(pw::rpc::test::TestResponse, payload, .value = 100); + std::array encoded_response = {}; + auto encoded = context.server_stream(payload).Encode(encoded_response); + ASSERT_EQ(OkStatus(), encoded.status()); + + ConstByteSpan sent_payload = context.output().last_packet().payload(); + EXPECT_TRUE(std::equal(payload.begin(), + payload.end(), + sent_payload.begin(), + sent_payload.end())); +} + +TEST(PwpbMethod, ServerWriter_WriteWhenClosed_ReturnsFailedPrecondition) { + ServerContextForTest context(kServerStream); + + rpc_lock().lock(); + kServerStream.Invoke(context.get(), context.request({})); + + EXPECT_EQ(OkStatus(), context.service().last_writer.Finish()); + EXPECT_TRUE(context.service() + .last_writer.Write({.value = 100}) + .IsFailedPrecondition()); +} + +TEST(PwpbMethod, ServerWriter_WriteAfterMoved_ReturnsFailedPrecondition) { + ServerContextForTest context(kServerStream); + + rpc_lock().lock(); + kServerStream.Invoke(context.get(), context.request({})); + PwpbServerWriter new_writer = + std::move(context.service().last_writer); + + EXPECT_EQ(OkStatus(), new_writer.Write({.value = 100})); + + EXPECT_EQ(Status::FailedPrecondition(), + context.service().last_writer.Write({.value = 100})); + EXPECT_EQ(Status::FailedPrecondition(), + context.service().last_writer.Finish()); + + EXPECT_EQ(OkStatus(), new_writer.Finish()); +} + +TEST(PwpbMethod, ServerStreamingRpc_ResponseEncodingFails_InternalError) { + ServerContextForTest context(kServerStream); + + rpc_lock().lock(); + kServerStream.Invoke(context.get(), context.request({})); + + EXPECT_EQ(OkStatus(), context.service().last_writer.Write({})); + + pw::rpc::test::TestResponse::Message response = {}; + response.repeated_field.SetEncoder( + [](const pw::rpc::test::TestResponse::StreamEncoder&) { + return Status::Internal(); + }); + EXPECT_EQ(Status::Internal(), context.service().last_writer.Write(response)); +} + +TEST(PwpbMethod, ServerReader_HandlesRequests) { + ServerContextForTest context(kClientStream); + + rpc_lock().lock(); + kClientStream.Invoke(context.get(), context.request({})); + + pw::rpc::test::TestRequest::Message request_struct{}; + context.service().last_reader.set_on_next( + [&request_struct](const pw::rpc::test::TestRequest::Message& req) { + request_struct = req; + }); + + PW_ENCODE_PB(pw::rpc::test::TestRequest, + request, + .integer = 1 << 30, + .status_code = 9); + std::array encoded_request = {}; + auto encoded = context.client_stream(request).Encode(encoded_request); + ASSERT_EQ(OkStatus(), encoded.status()); + ASSERT_EQ(OkStatus(), + context.server().ProcessPacket(*encoded, context.output())); + + EXPECT_EQ(request_struct.integer, 1 << 30); + EXPECT_EQ(request_struct.status_code, 9u); +} + +TEST(PwpbMethod, ServerReaderWriter_WritesResponses) { + ServerContextForTest context(kBidirectionalStream); + + rpc_lock().lock(); + kBidirectionalStream.Invoke(context.get(), context.request({})); + + EXPECT_EQ(OkStatus(), + context.service().last_reader_writer.Write({.value = 100})); + + PW_ENCODE_PB(pw::rpc::test::TestResponse, payload, .value = 100); + std::array encoded_response = {}; + auto encoded = context.server_stream(payload).Encode(encoded_response); + ASSERT_EQ(OkStatus(), encoded.status()); + + ConstByteSpan sent_payload = context.output().last_packet().payload(); + EXPECT_TRUE(std::equal(payload.begin(), + payload.end(), + sent_payload.begin(), + sent_payload.end())); +} + +TEST(PwpbMethod, ServerReaderWriter_HandlesRequests) { + ServerContextForTest context(kBidirectionalStream); + + rpc_lock().lock(); + kBidirectionalStream.Invoke(context.get(), context.request({})); + + pw::rpc::test::TestRequest::Message request_struct{}; + context.service().last_reader_writer.set_on_next( + [&request_struct](const pw::rpc::test::TestRequest::Message& req) { + request_struct = req; + }); + + PW_ENCODE_PB(pw::rpc::test::TestRequest, + request, + .integer = 1 << 29, + .status_code = 8); + std::array encoded_request = {}; + auto encoded = context.client_stream(request).Encode(encoded_request); + ASSERT_EQ(OkStatus(), encoded.status()); + ASSERT_EQ(OkStatus(), + context.server().ProcessPacket(*encoded, context.output())); + + EXPECT_EQ(request_struct.integer, 1 << 29); + EXPECT_EQ(request_struct.status_code, 8u); +} + +} // namespace +} // namespace pw::rpc::internal + +PW_MODIFY_DIAGNOSTICS_POP(); diff --git a/pw_rpc/pwpb/method_union_test.cc b/pw_rpc/pwpb/method_union_test.cc new file mode 100644 index 0000000000..cfae7d899f --- /dev/null +++ b/pw_rpc/pwpb/method_union_test.cc @@ -0,0 +1,170 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. +// + +#include "pw_rpc/pwpb/internal/method_union.h" + +#include + +#include "gtest/gtest.h" +#include "pw_rpc/internal/test_utils.h" +#include "pw_rpc_pwpb_private/internal_test_utils.h" +#include "pw_rpc_test_protos/test.pwpb.h" + +namespace pw::rpc::internal { +namespace { + +using std::byte; + +template +class FakeGeneratedService : public Service { + public: + constexpr FakeGeneratedService(uint32_t id) : Service(id, kMethods) {} + + static constexpr std::array kMethods = { + GetPwpbOrRawMethodFor<&Implementation::DoNothing, + MethodType::kUnary, + pw::rpc::test::Empty::Message, + pw::rpc::test::Empty::Message>( + 10u, + kPwpbMethodSerde<&pw::rpc::test::Empty::kMessageFields, + &pw::rpc::test::Empty::kMessageFields>), + GetPwpbOrRawMethodFor<&Implementation::RawStream, + MethodType::kServerStreaming, + pw::rpc::test::TestRequest::Message, + pw::rpc::test::TestResponse::Message>( + 11u, + kPwpbMethodSerde<&pw::rpc::test::TestRequest::kMessageFields, + &pw::rpc::test::TestResponse::kMessageFields>), + GetPwpbOrRawMethodFor<&Implementation::AddFive, + MethodType::kUnary, + pw::rpc::test::TestRequest::Message, + pw::rpc::test::TestResponse::Message>( + 12u, + kPwpbMethodSerde<&pw::rpc::test::TestRequest::kMessageFields, + &pw::rpc::test::TestResponse::kMessageFields>), + GetPwpbOrRawMethodFor<&Implementation::StartStream, + MethodType::kServerStreaming, + pw::rpc::test::TestRequest::Message, + pw::rpc::test::TestResponse::Message>( + 13u, + kPwpbMethodSerde<&pw::rpc::test::TestRequest::kMessageFields, + &pw::rpc::test::TestResponse::kMessageFields>), + }; +}; + +class FakeGeneratedServiceImpl + : public FakeGeneratedService { + public: + FakeGeneratedServiceImpl(uint32_t id) : FakeGeneratedService(id) {} + + Status AddFive(const pw::rpc::test::TestRequest::Message& request, + pw::rpc::test::TestResponse::Message& response) { + last_request = request; + response.value = request.integer + 5; + return Status::Unauthenticated(); + } + + void DoNothing(ConstByteSpan, RawUnaryResponder& responder) { + ASSERT_EQ(OkStatus(), responder.Finish({}, Status::Unknown())); + } + + void RawStream(ConstByteSpan, RawServerWriter& writer) { + last_raw_writer = std::move(writer); + } + + void StartStream( + const pw::rpc::test::TestRequest::Message& request, + PwpbServerWriter& writer) { + last_request = request; + last_writer = std::move(writer); + } + + pw::rpc::test::TestRequest::Message last_request; + PwpbServerWriter last_writer; + RawServerWriter last_raw_writer; +}; + +TEST(PwpbMethodUnion, Raw_CallsUnaryMethod) { + const Method& method = + std::get<0>(FakeGeneratedServiceImpl::kMethods).method(); + ServerContextForTest context(method); + rpc_lock().lock(); + method.Invoke(context.get(), context.request({})); + + const Packet& response = context.output().last_packet(); + EXPECT_EQ(response.status(), Status::Unknown()); +} + +TEST(PwpbMethodUnion, Raw_CallsServerStreamingMethod) { + PW_ENCODE_PB( + pw::rpc::test::TestRequest, request, .integer = 555, .status_code = 0); + + const Method& method = + std::get<1>(FakeGeneratedServiceImpl::kMethods).method(); + ServerContextForTest context(method); + + rpc_lock().lock(); + method.Invoke(context.get(), context.request(request)); + + EXPECT_TRUE(context.service().last_raw_writer.active()); + EXPECT_EQ(OkStatus(), context.service().last_raw_writer.Finish()); + EXPECT_EQ(context.output().last_packet().type(), PacketType::RESPONSE); +} + +TEST(PwpbMethodUnion, Pwpb_CallsUnaryMethod) { + PW_ENCODE_PB( + pw::rpc::test::TestRequest, request, .integer = 123, .status_code = 3); + + const Method& method = + std::get<2>(FakeGeneratedServiceImpl::kMethods).method(); + ServerContextForTest context(method); + rpc_lock().lock(); + method.Invoke(context.get(), context.request(request)); + + const Packet& response = context.output().last_packet(); + EXPECT_EQ(response.status(), Status::Unauthenticated()); + + // Field 1 (encoded as 1 << 3) with 128 as the value. + constexpr std::byte expected[]{ + std::byte{0x08}, std::byte{0x80}, std::byte{0x01}}; + + EXPECT_EQ(sizeof(expected), response.payload().size()); + EXPECT_EQ(0, + std::memcmp(expected, response.payload().data(), sizeof(expected))); + + EXPECT_EQ(123, context.service().last_request.integer); + EXPECT_EQ(3u, context.service().last_request.status_code); +} + +TEST(PwpbMethodUnion, Pwpb_CallsServerStreamingMethod) { + PW_ENCODE_PB( + pw::rpc::test::TestRequest, request, .integer = 555, .status_code = 0); + + const Method& method = + std::get<3>(FakeGeneratedServiceImpl::kMethods).method(); + ServerContextForTest context(method); + + rpc_lock().lock(); + method.Invoke(context.get(), context.request(request)); + + EXPECT_EQ(555, context.service().last_request.integer); + EXPECT_TRUE(context.service().last_writer.active()); + + EXPECT_EQ(OkStatus(), context.service().last_writer.Finish()); + EXPECT_EQ(context.output().last_packet().type(), PacketType::RESPONSE); +} + +} // namespace +} // namespace pw::rpc::internal diff --git a/pw_rpc/pwpb/public/pw_rpc/echo_service_pwpb.h b/pw_rpc/pwpb/public/pw_rpc/echo_service_pwpb.h new file mode 100644 index 0000000000..b15ef30549 --- /dev/null +++ b/pw_rpc/pwpb/public/pw_rpc/echo_service_pwpb.h @@ -0,0 +1,30 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. +#pragma once + +#include "pw_rpc/echo.rpc.pwpb.h" + +namespace pw::rpc { + +class EchoService final + : public pw_rpc::pwpb::EchoService::Service { + public: + Status Echo(const EchoMessage::Message& request, + EchoMessage::Message& response) { + response.msg = request.msg; + return OkStatus(); + } +}; + +} // namespace pw::rpc diff --git a/pw_rpc/pwpb/public/pw_rpc/pwpb/client_reader_writer.h b/pw_rpc/pwpb/public/pw_rpc/pwpb/client_reader_writer.h new file mode 100644 index 0000000000..6938bd4b48 --- /dev/null +++ b/pw_rpc/pwpb/public/pw_rpc/pwpb/client_reader_writer.h @@ -0,0 +1,477 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +// This file defines the ClientReaderWriter, ClientReader, ClientWriter, +// and UnaryReceiver classes for the pw_protobuf RPC interface. These classes +// are used for bidirectional, client, and server streaming, and unary RPCs. +#pragma once + +#include "pw_bytes/span.h" +#include "pw_function/function.h" +#include "pw_rpc/channel.h" +#include "pw_rpc/internal/client_call.h" +#include "pw_rpc/pwpb/internal/common.h" + +namespace pw::rpc { +namespace internal { + +// internal::PwpbUnaryResponseClientCall extends +// internal::UnaryResponseClientCall by adding a method serializer/deserializer +// passed in to Start(), typed request messages to the Start() call, and an +// on_completed callback templated on the response type. +template +class PwpbUnaryResponseClientCall : public UnaryResponseClientCall { + public: + // Start() can be called with zero or one request objects. + template + static CallType Start(Endpoint& client, + uint32_t channel_id, + uint32_t service_id, + uint32_t method_id, + const PwpbMethodSerde& serde, + Function&& on_completed, + Function&& on_error, + const Request&... request) { + rpc_lock().lock(); + CallType call(client, channel_id, service_id, method_id, serde); + + call.set_on_completed_locked(std::move(on_completed)); + call.set_on_error_locked(std::move(on_error)); + + if constexpr (sizeof...(Request) == 0u) { + call.SendInitialClientRequest({}); + } else { + PwpbSendInitialRequest(call, serde.request(), request...); + } + + return call; + } + + // Give access to the serializer/deserializer object for converting requests + // and responses between the wire format and pw_protobuf structs. + const PwpbMethodSerde& serde() const { return *serde_; } + + protected: + // Derived classes allow default construction so that users can declare a + // variable into which to move client reader/writers from RPC calls. + constexpr PwpbUnaryResponseClientCall() = default; + + PwpbUnaryResponseClientCall(internal::Endpoint& client, + uint32_t channel_id, + uint32_t service_id, + uint32_t method_id, + MethodType type, + const PwpbMethodSerde& serde) + : UnaryResponseClientCall( + client, channel_id, service_id, method_id, type), + serde_(&serde) {} + + // Allow derived classes to be constructed moving another instance. + PwpbUnaryResponseClientCall(PwpbUnaryResponseClientCall&& other) + PW_LOCKS_EXCLUDED(rpc_lock()) { + *this = std::move(other); + } + + // Allow derived classes to use move assignment from another instance. + PwpbUnaryResponseClientCall& operator=(PwpbUnaryResponseClientCall&& other) + PW_LOCKS_EXCLUDED(rpc_lock()) { + LockGuard lock(rpc_lock()); + MovePwpbUnaryResponseClientCallFrom(other); + return *this; + } + + // Implement moving by copying the serde pointer and on_completed function. + void MovePwpbUnaryResponseClientCallFrom(PwpbUnaryResponseClientCall& other) + PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) { + MoveUnaryResponseClientCallFrom(other); + serde_ = other.serde_; + set_on_completed_locked(std::move(other.pwpb_on_completed_)); + } + + void set_on_completed( + Function&& on_completed) + PW_LOCKS_EXCLUDED(rpc_lock()) { + LockGuard lock(rpc_lock()); + set_on_completed_locked(std::move(on_completed)); + } + + // Sends a streamed request. + // Returns the following Status codes: + // + // OK - the request was successfully sent + // FAILED_PRECONDITION - the writer is closed + // INTERNAL - pw_rpc was unable to encode the pw_protobuf protobuf + // other errors - the ChannelOutput failed to send the packet; the error + // codes are determined by the ChannelOutput implementation + // + template + Status SendStreamRequest(const Request& request) + PW_LOCKS_EXCLUDED(rpc_lock()) { + LockGuard lock(rpc_lock()); + if (!active_locked()) { + return Status::FailedPrecondition(); + } + + return PwpbSendStream(*this, request, serde().request()); + } + + private: + void set_on_completed_locked( + Function&& on_completed) + PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) { + pwpb_on_completed_ = std::move(on_completed); + + UnaryResponseClientCall::set_on_completed_locked( + [this](ConstByteSpan payload, Status status) { + if (pwpb_on_completed_) { + Response response{}; + const Status decode_status = + serde().DecodeResponse(payload, response); + if (decode_status.ok()) { + pwpb_on_completed_(response, status); + } else { + rpc_lock().lock(); + CallOnError(Status::DataLoss()); + } + } + }); + } + + const PwpbMethodSerde* serde_; + Function pwpb_on_completed_; +}; + +// internal::PwpbStreamResponseClientCall extends +// internal::StreamResponseClientCall by adding a method serializer/deserializer +// passed in to Start(), typed request messages to the Start() call, and an +// on_next callback templated on the response type. +template +class PwpbStreamResponseClientCall : public StreamResponseClientCall { + public: + // Start() can be called with zero or one request objects. + template + static CallType Start(Endpoint& client, + uint32_t channel_id, + uint32_t service_id, + uint32_t method_id, + const PwpbMethodSerde& serde, + Function&& on_next, + Function&& on_completed, + Function&& on_error, + const Request&... request) { + rpc_lock().lock(); + CallType call(client, channel_id, service_id, method_id, serde); + + call.set_on_next_locked(std::move(on_next)); + call.set_on_completed_locked(std::move(on_completed)); + call.set_on_error_locked(std::move(on_error)); + + if constexpr (sizeof...(Request) == 0u) { + call.SendInitialClientRequest({}); + } else { + PwpbSendInitialRequest(call, serde.request(), request...); + } + return call; + } + + // Give access to the serializer/deserializer object for converting requests + // and responses between the wire format and pw_protobuf structs. + const PwpbMethodSerde& serde() const { return *serde_; } + + protected: + // Derived classes allow default construction so that users can declare a + // variable into which to move client reader/writers from RPC calls. + constexpr PwpbStreamResponseClientCall() = default; + + PwpbStreamResponseClientCall(internal::Endpoint& client, + uint32_t channel_id, + uint32_t service_id, + uint32_t method_id, + MethodType type, + const PwpbMethodSerde& serde) + : StreamResponseClientCall( + client, channel_id, service_id, method_id, type), + serde_(&serde) {} + + // Allow derived classes to be constructed moving another instance. + PwpbStreamResponseClientCall(PwpbStreamResponseClientCall&& other) + PW_LOCKS_EXCLUDED(rpc_lock()) { + *this = std::move(other); + } + + // Allow derived classes to use move assignment from another instance. + PwpbStreamResponseClientCall& operator=(PwpbStreamResponseClientCall&& other) + PW_LOCKS_EXCLUDED(rpc_lock()) { + LockGuard lock(rpc_lock()); + MovePwpbStreamResponseClientCallFrom(other); + return *this; + } + + // Implement moving by copying the serde pointer and on_next function. + void MovePwpbStreamResponseClientCallFrom(PwpbStreamResponseClientCall& other) + PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) { + MoveStreamResponseClientCallFrom(other); + serde_ = other.serde_; + set_on_next_locked(std::move(other.pwpb_on_next_)); + } + + void set_on_next(Function&& on_next) + PW_LOCKS_EXCLUDED(rpc_lock()) { + LockGuard lock(rpc_lock()); + set_on_next_locked(std::move(on_next)); + } + + // Sends a streamed request. + // Returns the following Status codes: + // + // OK - the request was successfully sent + // FAILED_PRECONDITION - the writer is closed + // INTERNAL - pw_rpc was unable to encode the pw_protobuf protobuf + // other errors - the ChannelOutput failed to send the packet; the error + // codes are determined by the ChannelOutput implementation + // + template + Status SendStreamRequest(const Request& request) + PW_LOCKS_EXCLUDED(rpc_lock()) { + LockGuard lock(rpc_lock()); + if (!active_locked()) { + return Status::FailedPrecondition(); + } + + return PwpbSendStream(*this, request, serde().request()); + } + + private: + void set_on_next_locked(Function&& on_next) + PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) { + pwpb_on_next_ = std::move(on_next); + + Call::set_on_next_locked([this](ConstByteSpan payload) { + if (pwpb_on_next_) { + Response response{}; + const Status status = serde().DecodeResponse(payload, response); + if (status.ok()) { + pwpb_on_next_(response); + } else { + rpc_lock().lock(); + CallOnError(Status::DataLoss()); + } + } + }); + } + + const PwpbMethodSerde* serde_; + Function pwpb_on_next_; +}; + +} // namespace internal + +// The PwpbClientReaderWriter is used to send and receive typed messages in a +// pw_protobuf bidirectional streaming RPC. +// +// These classes use private inheritance to hide the internal::Call API while +// allow direct use of its public and protected functions. +template +class PwpbClientReaderWriter + : private internal::PwpbStreamResponseClientCall { + public: + // Allow default construction so that users can declare a variable into + // which to move client reader/writers from RPC calls. + constexpr PwpbClientReaderWriter() = default; + + PwpbClientReaderWriter(PwpbClientReaderWriter&&) = default; + PwpbClientReaderWriter& operator=(PwpbClientReaderWriter&&) = default; + + using internal::Call::active; + using internal::Call::channel_id; + + // Writes a request. Returns the following Status codes: + // + // OK - the request was successfully sent + // FAILED_PRECONDITION - the writer is closed + // INTERNAL - pw_rpc was unable to encode the pw_protobuf message + // other errors - the ChannelOutput failed to send the packet; the error + // codes are determined by the ChannelOutput implementation + // + Status Write(const Request& request) { + return internal::PwpbStreamResponseClientCall::SendStreamRequest( + request); + } + + using internal::Call::Cancel; + using internal::Call::CloseClientStream; + + // Functions for setting RPC event callbacks. + using internal::PwpbStreamResponseClientCall::set_on_next; + using internal::StreamResponseClientCall::set_on_completed; + using internal::StreamResponseClientCall::set_on_error; + + protected: + friend class internal::PwpbStreamResponseClientCall; + + PwpbClientReaderWriter(internal::Endpoint& client, + uint32_t channel_id_v, + uint32_t service_id, + uint32_t method_id, + const internal::PwpbMethodSerde& serde) + : internal::PwpbStreamResponseClientCall( + client, + channel_id_v, + service_id, + method_id, + MethodType::kBidirectionalStreaming, + serde) {} +}; + +// The PwpbClientReader is used to receive typed messages and send a typed +// response in a pw_protobuf client streaming RPC. +// +// These classes use private inheritance to hide the internal::Call API while +// allow direct use of its public and protected functions. +template +class PwpbClientReader + : private internal::PwpbStreamResponseClientCall { + public: + // Allow default construction so that users can declare a variable into + // which to move client reader/writers from RPC calls. + constexpr PwpbClientReader() = default; + + PwpbClientReader(PwpbClientReader&&) = default; + PwpbClientReader& operator=(PwpbClientReader&&) = default; + + using internal::StreamResponseClientCall::active; + using internal::StreamResponseClientCall::channel_id; + + using internal::Call::Cancel; + + // Functions for setting RPC event callbacks. + using internal::PwpbStreamResponseClientCall::set_on_next; + using internal::StreamResponseClientCall::set_on_completed; + using internal::StreamResponseClientCall::set_on_error; + + private: + friend class internal::PwpbStreamResponseClientCall; + + PwpbClientReader(internal::Endpoint& client, + uint32_t channel_id_v, + uint32_t service_id, + uint32_t method_id, + const internal::PwpbMethodSerde& serde) + : internal::PwpbStreamResponseClientCall( + client, + channel_id_v, + service_id, + method_id, + MethodType::kServerStreaming, + serde) {} +}; + +// The PwpbClientWriter is used to send typed responses in a pw_protobuf server +// streaming RPC. +// +// These classes use private inheritance to hide the internal::Call API while +// allow direct use of its public and protected functions. +template +class PwpbClientWriter + : private internal::PwpbUnaryResponseClientCall { + public: + // Allow default construction so that users can declare a variable into + // which to move client reader/writers from RPC calls. + constexpr PwpbClientWriter() = default; + + PwpbClientWriter(PwpbClientWriter&&) = default; + PwpbClientWriter& operator=(PwpbClientWriter&&) = default; + + using internal::UnaryResponseClientCall::active; + using internal::UnaryResponseClientCall::channel_id; + + // Writes a request. Returns the following Status codes: + // + // OK - the request was successfully sent + // FAILED_PRECONDITION - the writer is closed + // INTERNAL - pw_rpc was unable to encode the pw_protobuf message + // other errors - the ChannelOutput failed to send the packet; the error + // codes are determined by the ChannelOutput implementation + // + Status Write(const Request& request) { + return internal::PwpbUnaryResponseClientCall::SendStreamRequest( + request); + } + + using internal::Call::Cancel; + using internal::Call::CloseClientStream; + + // Functions for setting RPC event callbacks. + using internal::PwpbUnaryResponseClientCall::set_on_completed; + using internal::UnaryResponseClientCall::set_on_error; + + private: + friend class internal::PwpbUnaryResponseClientCall; + + PwpbClientWriter(internal::Endpoint& client, + uint32_t channel_id_v, + uint32_t service_id, + uint32_t method_id, + const internal::PwpbMethodSerde& serde) + : internal::PwpbUnaryResponseClientCall( + client, + channel_id_v, + service_id, + method_id, + MethodType::kClientStreaming, + serde) {} +}; + +// The PwpbUnaryReceiver is used to handle a typed response to a pw_protobuf +// unary RPC. +// +// These classes use private inheritance to hide the internal::Call API while +// allow direct use of its public and protected functions. +template +class PwpbUnaryReceiver + : private internal::PwpbUnaryResponseClientCall { + public: + // Allow default construction so that users can declare a variable into + // which to move client reader/writers from RPC calls. + constexpr PwpbUnaryReceiver() = default; + + PwpbUnaryReceiver(PwpbUnaryReceiver&&) = default; + PwpbUnaryReceiver& operator=(PwpbUnaryReceiver&&) = default; + + using internal::Call::active; + using internal::Call::channel_id; + + // Functions for setting RPC event callbacks. + using internal::Call::set_on_error; + using internal::PwpbUnaryResponseClientCall::set_on_completed; + + using internal::Call::Cancel; + + private: + friend class internal::PwpbUnaryResponseClientCall; + + PwpbUnaryReceiver(internal::Endpoint& client, + uint32_t channel_id_v, + uint32_t service_id, + uint32_t method_id, + const internal::PwpbMethodSerde& serde) + : internal::PwpbUnaryResponseClientCall(client, + channel_id_v, + service_id, + method_id, + MethodType::kUnary, + serde) {} +}; + +} // namespace pw::rpc diff --git a/pw_rpc/pwpb/public/pw_rpc/pwpb/client_testing.h b/pw_rpc/pwpb/public/pw_rpc/pwpb/client_testing.h new file mode 100644 index 0000000000..cc4e05bc3b --- /dev/null +++ b/pw_rpc/pwpb/public/pw_rpc/pwpb/client_testing.h @@ -0,0 +1,114 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. +#pragma once + +#include +#include + +#include "pw_bytes/span.h" +#include "pw_rpc/client.h" +#include "pw_rpc/internal/method_info.h" +#include "pw_rpc/pwpb/fake_channel_output.h" +#include "pw_rpc/raw/client_testing.h" + +namespace pw::rpc { + +// TODO(pwbug/477): Document the client testing APIs. + +// Sends packets to an RPC client as if it were a pw_rpc server. Accepts +// payloads as pw_protobuf message structs. +class PwpbFakeServer : public FakeServer { + private: + template + using Response = typename internal::MethodInfo::Response; + + public: + using FakeServer::FakeServer; + + // Sends a response packet for a server or bidirectional streaming RPC to the + // client. + template + void SendResponse(Status status) const { + FakeServer::SendResponse(status); + } + + // Sends a response packet for a unary or client streaming streaming RPC to + // the client. + template )> + void SendResponse(const Response& payload, Status status) const { + std::byte buffer[kEncodeBufferSizeBytes] = {}; + FakeServer::SendResponse(EncodeResponse(payload, buffer), + status); + } + + // Sends a stream packet for a server or bidirectional streaming RPC to the + // client. + template )> + void SendServerStream(const Response& payload) const { + std::byte buffer[kEncodeBufferSizeBytes] = {}; + FakeServer::SendServerStream( + EncodeResponse(payload, buffer)); + } + + private: + template + static ConstByteSpan EncodeResponse(const Response& payload, + ByteSpan buffer) { + const StatusWithSize result = + internal::MethodInfo::serde().EncodeResponse(payload, buffer); + PW_ASSERT(result.ok()); + return std::span(buffer).first(result.size()); + } +}; + +// Instantiates a PwpbFakeServer, Client, Channel, and PwpbFakeChannelOutput +// for testing RPC client calls. These components may be used individually, but +// are instantiated together for convenience. +template +class PwpbClientTestContext { + public: + constexpr PwpbClientTestContext() + : channel_(Channel::Create(&channel_output_)), + client_(std::span(&channel_, 1)), + packet_buffer_{}, + fake_server_( + channel_output_, client_, kDefaultChannelId, packet_buffer_) {} + + const Channel& channel() const { return channel_; } + Channel& channel() { return channel_; } + + const PwpbFakeServer& server() const { return fake_server_; } + PwpbFakeServer& server() { return fake_server_; } + + const Client& client() const { return client_; } + Client& client() { return client_; } + + const auto& output() const { return channel_output_; } + auto& output() { return channel_output_; } + + private: + static constexpr uint32_t kDefaultChannelId = 1; + + PwpbFakeChannelOutput channel_output_; + Channel channel_; + Client client_; + std::byte packet_buffer_[kPacketEncodeBufferSizeBytes]; + PwpbFakeServer fake_server_; +}; + +} // namespace pw::rpc diff --git a/pw_rpc/pwpb/public/pw_rpc/pwpb/fake_channel_output.h b/pw_rpc/pwpb/public/pw_rpc/pwpb/fake_channel_output.h new file mode 100644 index 0000000000..323fca8fd8 --- /dev/null +++ b/pw_rpc/pwpb/public/pw_rpc/pwpb/fake_channel_output.h @@ -0,0 +1,189 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. +#pragma once + +#include + +#include "pw_containers/wrapped_iterator.h" +#include "pw_rpc/internal/fake_channel_output.h" +#include "pw_rpc/pwpb/internal/common.h" +#include "pw_rpc/pwpb/internal/method.h" + +namespace pw::rpc { +namespace internal::test::pwpb { + +// Forward declare for a friend statement. +template +class PwpbInvocationContext; + +} // namespace internal::test::pwpb + +// PwpbPayloadsView supports iterating over payloads as decoded pw_protobuf +// request or response message structs. +template +class PwpbPayloadsView { + public: + class iterator : public containers::WrappedIterator { + public: + // Access the payload (rather than packet) with operator*. + Payload operator*() const { + Payload payload{}; + PW_ASSERT(serde_ + .Decode(containers::WrappedIterator::value(), + payload) + .ok()); + return payload; + } + + private: + friend class PwpbPayloadsView; + + constexpr iterator(const PayloadsView::iterator& it, + const internal::PwpbSerde& serde) + : containers:: + WrappedIterator(it), + serde_(serde) {} + + internal::PwpbSerde serde_; + }; + + Payload operator[](size_t index) const { + Payload payload{}; + PW_ASSERT(serde_.Decode(view_[index], payload).ok()); + return payload; + } + + size_t size() const { return view_.size(); } + bool empty() const { return view_.empty(); } + + // Returns the first/last payload for the RPC. size() must be > 0. + Payload front() const { return *begin(); } + Payload back() const { return *std::prev(end()); } + + iterator begin() const { return iterator(view_.begin(), serde_); } + iterator end() const { return iterator(view_.end(), serde_); } + + private: + template + friend class PwpbFakeChannelOutput; + + template + PwpbPayloadsView(const internal::PwpbSerde& serde, Args&&... args) + : view_(args...), serde_(serde) {} + + PayloadsView view_; + internal::PwpbSerde serde_; +}; + +// A ChannelOutput implementation that stores the outgoing payloads and status. +template +class PwpbFakeChannelOutput final + : public internal::test::FakeChannelOutputBuffer { + private: + template + using Request = typename internal::MethodInfo::Request; + template + using Response = typename internal::MethodInfo::Response; + + public: + PwpbFakeChannelOutput() = default; + + // Iterates over request payloads from request or client stream packets. + // + // !!! WARNING !!! + // + // Access to the FakeChannelOutput through the PwpbPayloadsView is NOT + // synchronized! The PwpbPayloadsView is immediately invalidated if any + // thread accesses the FakeChannelOutput. + template + PwpbPayloadsView> requests( + uint32_t channel_id = Channel::kUnassignedChannelId) const { + constexpr internal::PacketType packet_type = + HasClientStream(internal::MethodInfo::kType) + ? internal::PacketType::CLIENT_STREAM + : internal::PacketType::REQUEST; + return PwpbPayloadsView>( + internal::MethodInfo::serde().request(), + internal::test::FakeChannelOutputBuffer< + kMaxPackets, + kPayloadsBufferSizeBytes>::packets(), + packet_type, + packet_type, + channel_id, + internal::MethodInfo::kServiceId, + internal::MethodInfo::kMethodId); + } + + // Iterates over response payloads from response or server stream packets. + // + // !!! WARNING !!! + // + // Access to the FakeChannelOutput through the PwpbPayloadsView is NOT + // synchronized! The PwpbPayloadsView is immediately invalidated if any + // thread accesses the FakeChannelOutput. + template + PwpbPayloadsView> responses( + uint32_t channel_id = Channel::kUnassignedChannelId) const { + constexpr internal::PacketType packet_type = + HasServerStream(internal::MethodInfo::kType) + ? internal::PacketType::SERVER_STREAM + : internal::PacketType::RESPONSE; + return PwpbPayloadsView>( + internal::MethodInfo::serde().response(), + internal::test::FakeChannelOutputBuffer< + kMaxPackets, + kPayloadsBufferSizeBytes>::packets(), + packet_type, + packet_type, + channel_id, + internal::MethodInfo::kServiceId, + internal::MethodInfo::kMethodId); + } + + template + Response last_response() const { + PwpbPayloadsView> payloads = responses(); + PW_ASSERT(!payloads.empty()); + return payloads.back(); + } + + private: + template + friend class internal::test::pwpb::PwpbInvocationContext; + + using internal::test::FakeChannelOutput::last_packet; + + template + PwpbPayloadsView payload_structs(const internal::PwpbSerde& serde, + MethodType type, + uint32_t channel_id, + uint32_t service_id, + uint32_t method_id) const { + return PwpbPayloadsView(serde, + internal::test::FakeChannelOutputBuffer< + kMaxPackets, + kPayloadsBufferSizeBytes>::packets(), + type, + channel_id, + service_id, + method_id); + } +}; + +} // namespace pw::rpc diff --git a/pw_rpc/pwpb/public/pw_rpc/pwpb/internal/common.h b/pw_rpc/pwpb/public/pw_rpc/pwpb/internal/common.h new file mode 100644 index 0000000000..c147bf8503 --- /dev/null +++ b/pw_rpc/pwpb/public/pw_rpc/pwpb/internal/common.h @@ -0,0 +1,187 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. +#pragma once + +#include + +#include "pw_assert/check.h" +#include "pw_bytes/span.h" +#include "pw_protobuf/encoder.h" +#include "pw_protobuf/internal/codegen.h" +#include "pw_protobuf/stream_decoder.h" +#include "pw_rpc/internal/client_call.h" +#include "pw_rpc/internal/server_call.h" +#include "pw_status/status.h" +#include "pw_status/status_with_size.h" + +namespace pw::rpc::internal { + +using PwpbMessageDescriptor = const std::span*; + +// Serializer/deserializer for a pw_protobuf message. +class PwpbSerde { + public: + explicit constexpr PwpbSerde(PwpbMessageDescriptor table) : table_(table) {} + + PwpbSerde(const PwpbSerde&) = default; + PwpbSerde& operator=(const PwpbSerde&) = default; + + // Encodes a pw_protobuf struct to the serialized wire format. + template + StatusWithSize Encode(const Message& message, ByteSpan buffer) const { + return Encoder(buffer).Write(std::as_bytes(std::span(&message, 1)), table_); + } + + // Decodes a serialized protobuf into a pw_protobuf message struct. + template + Status Decode(ConstByteSpan buffer, Message& message) const { + return Decoder(buffer).Read(std::as_writable_bytes(std::span(&message, 1)), + table_); + } + + private: + class Encoder : public protobuf::MemoryEncoder { + public: + constexpr Encoder(ByteSpan buffer) : protobuf::MemoryEncoder(buffer) {} + + StatusWithSize Write(ConstByteSpan message, PwpbMessageDescriptor table) { + const auto status = protobuf::MemoryEncoder::Write(message, *table); + return StatusWithSize(status, size()); + } + }; + + class Decoder : public protobuf::StreamDecoder { + public: + constexpr Decoder(ConstByteSpan buffer) + : protobuf::StreamDecoder(reader_), reader_(buffer) {} + + Status Read(ByteSpan message, PwpbMessageDescriptor table) { + return protobuf::StreamDecoder::Read(message, *table); + } + + private: + stream::MemoryReader reader_; + }; + + PwpbMessageDescriptor table_; +}; + +// Serializer/deserializer for pw_protobuf request and response message structs +// within an RPC method. +class PwpbMethodSerde { + public: + constexpr PwpbMethodSerde(PwpbMessageDescriptor request_table, + PwpbMessageDescriptor response_table) + : request_serde_(request_table), response_serde_(response_table) {} + + PwpbMethodSerde(const PwpbMethodSerde&) = delete; + PwpbMethodSerde& operator=(const PwpbMethodSerde&) = delete; + + // Encodes the pw_protobuf request struct to the serialized wire format. + template + StatusWithSize EncodeRequest(const Request& request, ByteSpan buffer) const { + return request_serde_.Encode(request, buffer); + } + + // Encodes the pw_protobuf response struct to the serialized wire format. + template + StatusWithSize EncodeResponse(const Response& response, + ByteSpan buffer) const { + return response_serde_.Encode(response, buffer); + } + // Decodes a serialized protobuf into the pw_protobuf request struct. + template + Status DecodeRequest(ConstByteSpan buffer, Request& request) const { + return request_serde_.Decode(buffer, request); + } + + // Decodes a serialized protobuf into the pw_protobuf response struct. + template + Status DecodeResponse(ConstByteSpan buffer, Response& response) const { + return response_serde_.Decode(buffer, response); + } + + const PwpbSerde& request() const { return request_serde_; } + const PwpbSerde& response() const { return response_serde_; } + + private: + PwpbSerde request_serde_; + PwpbSerde response_serde_; +}; + +// Defines per-message struct type instance of the serializer/deserializer. +template +constexpr PwpbMethodSerde kPwpbMethodSerde(kRequest, kResponse); + +// Encodes a message struct into a payload buffer. +template +Result PwpbEncodeToPayloadBuffer(const Payload& payload, + PwpbSerde serde) + PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) { + ByteSpan buffer = GetPayloadBuffer(); + const StatusWithSize sws = serde.Encode(payload, buffer); + if (!sws.ok()) { + return sws.status(); + } + return buffer.first(sws.size()); +} + +// [Client] Encodes and sends the initial request message for the call. +// active() must be true. +template +void PwpbSendInitialRequest(ClientCall& call, + PwpbSerde serde, + const Request& request) + PW_UNLOCK_FUNCTION(rpc_lock()) { + PW_DCHECK(call.active_locked()); + + Result buffer = PwpbEncodeToPayloadBuffer(request, serde); + if (buffer.ok()) { + call.SendInitialClientRequest(*buffer); + } else { + call.HandleError(buffer.status()); + } +} + +// [Client/Server] Encodes and sends a client or server stream message. +// active() must be true. +template +Status PwpbSendStream(Call& call, const Payload& payload, PwpbSerde serde) + PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) { + Result buffer = PwpbEncodeToPayloadBuffer(payload, serde); + PW_TRY(buffer); + + return call.WriteLocked(*buffer); +} + +// [Server] Encodes and sends the final response message from an untyped +// ConstByteSpan. +// active() must be true. +template +Status PwpbSendFinalResponse(internal::ServerCall& call, + const Response& response, + Status status, + PwpbSerde serde) + PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) { + PW_DCHECK(call.active_locked()); + + Result buffer = PwpbEncodeToPayloadBuffer(response, serde); + if (!buffer.ok()) { + return call.CloseAndSendServerErrorLocked(Status::Internal()); + } + + return call.CloseAndSendResponseLocked(*buffer, status); +} + +} // namespace pw::rpc::internal diff --git a/pw_rpc/pwpb/public/pw_rpc/pwpb/internal/method.h b/pw_rpc/pwpb/public/pw_rpc/pwpb/internal/method.h new file mode 100644 index 0000000000..0d774ac63d --- /dev/null +++ b/pw_rpc/pwpb/public/pw_rpc/pwpb/internal/method.h @@ -0,0 +1,459 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. +#pragma once + +#include +#include +#include +#include + +#include "pw_bytes/span.h" +#include "pw_rpc/internal/call_context.h" +#include "pw_rpc/internal/lock.h" +#include "pw_rpc/internal/method.h" +#include "pw_rpc/internal/packet.h" +#include "pw_rpc/method_type.h" +#include "pw_rpc/pwpb/internal/common.h" +#include "pw_rpc/pwpb/server_reader_writer.h" +#include "pw_rpc/service.h" +#include "pw_status/status_with_size.h" + +namespace pw::rpc::internal { + +// Expected function signatures for user-implemented RPC functions. +template +using PwpbSynchronousUnary = Status(const Request&, Response&); + +template +using PwpbAsynchronousUnary = void(const Request&, + PwpbUnaryResponder&); + +template +using PwpbServerStreaming = void(const Request&, PwpbServerWriter&); + +template +using PwpbClientStreaming = void(PwpbServerReader&); + +template +using PwpbBidirectionalStreaming = + void(PwpbServerReaderWriter&); + +// The PwpbMethod class invokes user-defined service methods. When a +// pw::rpc::Server receives an RPC request packet, it looks up the matching +// PwpbMethod instance and calls its Invoke method, which eventually calls into +// the user-defined RPC function. +// +// A PwpbMethod instance is created for each user-defined RPC in the pw_rpc +// generated code. The PwpbMethod stores a pointer to the RPC function, +// a pointer to an "invoker" function that calls that function, and a +// reference to a serializer/deserializer initiiated with the message struct +// tables used to encode and decode request and response message structs. +class PwpbMethod : public Method { + public: + template + static constexpr bool matches() { + return std::conjunction_v< + std::is_same, PwpbMethod>, + std::is_same>, + std::is_same>>; + } + + // Creates a PwpbMethod for a synchronous unary RPC. + // TODO(pwbug/661): Find a way to reduce the number of monomorphized copies + // of this method. + template + static constexpr PwpbMethod SynchronousUnary(uint32_t id, + const PwpbMethodSerde& serde) { + // Define a wrapper around the user-defined function that takes the + // request and response protobuf structs as byte spans, and calls the + // implementation with the correct type. + // + // This wrapper is stored generically in the Function union, defined below. + // In optimized builds, the compiler inlines the user-defined function into + // this wrapper, elminating any overhead. + constexpr SynchronousUnaryFunction wrapper = + [](Service& service, const void* request, void* response) { + return CallMethodImplFunction( + service, + *reinterpret_cast*>(request), + *reinterpret_cast*>(response)); + }; + return PwpbMethod( + id, + SynchronousUnaryInvoker, Response>, + Function{.synchronous_unary = wrapper}, + serde); + } + + // Creates a PwpbMethod for an asynchronous unary RPC. + // TODO(pwbug/661): Find a way to reduce the number of monomorphized copies + // of this method. + template + static constexpr PwpbMethod AsynchronousUnary(uint32_t id, + const PwpbMethodSerde& serde) { + // Define a wrapper around the user-defined function that takes the + // request struct as a byte span, the response as a server call, and calls + // the implementation with the correct types. + // + // This wrapper is stored generically in the Function union, defined below. + // In optimized builds, the compiler inlines the user-defined function into + // this wrapper, elminating any overhead. + constexpr UnaryRequestFunction wrapper = + [](Service& service, + const void* request, + internal::PwpbServerCall& writer) { + return CallMethodImplFunction( + service, + *reinterpret_cast*>(request), + static_cast>&>(writer)); + }; + return PwpbMethod(id, + AsynchronousUnaryInvoker>, + Function{.unary_request = wrapper}, + serde); + } + + // Creates a PwpbMethod for a server-streaming RPC. + template + static constexpr PwpbMethod ServerStreaming(uint32_t id, + const PwpbMethodSerde& serde) { + // Define a wrapper around the user-defined function that takes the + // request struct as a byte span, the response as a server call, and calls + // the implementation with the correct types. + // + // This wrapper is stored generically in the Function union, defined below. + // In optimized builds, the compiler inlines the user-defined function into + // this wrapper, elminating any overhead. + constexpr UnaryRequestFunction wrapper = + [](Service& service, + const void* request, + internal::PwpbServerCall& writer) { + return CallMethodImplFunction( + service, + *reinterpret_cast*>(request), + static_cast>&>(writer)); + }; + return PwpbMethod(id, + ServerStreamingInvoker>, + Function{.unary_request = wrapper}, + serde); + } + + // Creates a PwpbMethod for a client-streaming RPC. + template + static constexpr PwpbMethod ClientStreaming(uint32_t id, + const PwpbMethodSerde& serde) { + // Define a wrapper around the user-defined function that takes the + // request as a server call, and calls the implementation with the correct + // types. + // + // This wrapper is stored generically in the Function union, defined below. + // In optimized builds, the compiler inlines the user-defined function into + // this wrapper, elminating any overhead. + constexpr StreamRequestFunction wrapper = [](Service& service, + internal::PwpbServerCall& + reader) { + return CallMethodImplFunction( + service, + static_cast, Response>&>( + reader)); + }; + return PwpbMethod(id, + ClientStreamingInvoker>, + Function{.stream_request = wrapper}, + serde); + } + + // Creates a PwpbMethod for a bidirectional-streaming RPC. + template + static constexpr PwpbMethod BidirectionalStreaming( + uint32_t id, const PwpbMethodSerde& serde) { + // Define a wrapper around the user-defined function that takes the + // request and response as a server call, and calls the implementation with + // the correct types. + // + // This wrapper is stored generically in the Function union, defined below. + // In optimized builds, the compiler inlines the user-defined function into + // this wrapper, elminating any overhead. + constexpr StreamRequestFunction wrapper = + [](Service& service, internal::PwpbServerCall& reader_writer) { + return CallMethodImplFunction( + service, + static_cast< + PwpbServerReaderWriter, Response>&>( + reader_writer)); + }; + return PwpbMethod(id, + BidirectionalStreamingInvoker>, + Function{.stream_request = wrapper}, + serde); + } + + // Represents an invalid method. Used to reduce error message verbosity. + static constexpr PwpbMethod Invalid() { + return {0, InvalidInvoker, {}, PwpbMethodSerde(nullptr, nullptr)}; + } + + // Give access to the serializer/deserializer object for converting requests + // and responses between the wire format and pw_protobuf structs. + const PwpbMethodSerde& serde() const { return serde_; } + + private: + // Generic function signature for synchronous unary RPCs. + using SynchronousUnaryFunction = Status (*)(Service&, + const void* request, + void* response); + + // Generic function signature for asynchronous unary and server streaming + // RPCs. + using UnaryRequestFunction = void (*)(Service&, + const void* request, + internal::PwpbServerCall& writer); + + // Generic function signature for client and bidirectional streaming RPCs. + using StreamRequestFunction = + void (*)(Service&, internal::PwpbServerCall& reader_writer); + + // The Function union stores a pointer to a generic version of the + // user-defined RPC function. Using a union instead of void* avoids + // reinterpret_cast, which keeps this class fully constexpr. + union Function { + SynchronousUnaryFunction synchronous_unary; + UnaryRequestFunction unary_request; + StreamRequestFunction stream_request; + }; + + constexpr PwpbMethod(uint32_t id, + Invoker invoker, + Function function, + const PwpbMethodSerde& serde) + : Method(id, invoker), function_(function), serde_(serde) {} + + template + void CallSynchronousUnary(const CallContext& context, + const Packet& request, + Request& request_struct, + Response& response_struct) const + PW_UNLOCK_FUNCTION(rpc_lock()) { + if (!DecodeRequest(context, request, request_struct).ok()) { + rpc_lock().unlock(); + return; + } + + internal::PwpbServerCall responder(context, MethodType::kUnary); + rpc_lock().unlock(); + const Status status = function_.synchronous_unary( + context.service(), &request_struct, &response_struct); + responder.SendUnaryResponse(response_struct, status).IgnoreError(); + } + + template + void CallUnaryRequest(const CallContext& context, + MethodType method_type, + const Packet& request, + Request& request_struct) const + PW_UNLOCK_FUNCTION(rpc_lock()) { + if (!DecodeRequest(context, request, request_struct).ok()) { + rpc_lock().unlock(); + return; + } + + internal::PwpbServerCall server_writer(context, method_type); + rpc_lock().unlock(); + function_.unary_request(context.service(), &request_struct, server_writer); + } + + // Decodes a request protobuf into the provided buffer. Sends an error packet + // if the request failed to decode. + template + Status DecodeRequest(const CallContext& context, + const Packet& request, + Request& request_struct) const + PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) { + const auto status = serde_.DecodeRequest(request.payload(), request_struct); + if (status.ok()) { + return status; + } + + // The channel is known to exist. It was found when the request was + // processed and the lock has been held since, so GetInternalChannel cannot + // fail. + context.server() + .GetInternalChannel(context.channel_id()) + ->Send(Packet::ServerError(request, Status::DataLoss())) + .IgnoreError(); + return status; + } + + // Invoker function for synchronous unary RPCs. + template + static void SynchronousUnaryInvoker(const CallContext& context, + const Packet& request) + PW_UNLOCK_FUNCTION(rpc_lock()) { + Request request_struct{}; + Response response_struct{}; + + static_cast(context.method()) + .CallSynchronousUnary( + context, request, request_struct, response_struct); + } + + // Invoker function for asynchronous unary RPCs. + template + static void AsynchronousUnaryInvoker(const CallContext& context, + const Packet& request) + PW_UNLOCK_FUNCTION(rpc_lock()) { + Request request_struct{}; + + static_cast(context.method()) + .CallUnaryRequest(context, MethodType::kUnary, request, request_struct); + } + + // Invoker function for server streaming RPCs. + template + static void ServerStreamingInvoker(const CallContext& context, + const Packet& request) + PW_UNLOCK_FUNCTION(rpc_lock()) { + Request request_struct{}; + + static_cast(context.method()) + .CallUnaryRequest( + context, MethodType::kServerStreaming, request, request_struct); + } + + // Invoker function for client streaming RPCs. + template + static void ClientStreamingInvoker(const CallContext& context, const Packet&) + PW_UNLOCK_FUNCTION(rpc_lock()) { + internal::BasePwpbServerReader reader( + context, MethodType::kClientStreaming); + rpc_lock().unlock(); + static_cast(context.method()) + .function_.stream_request(context.service(), reader); + } + + // Invoker function for bidirectional streaming RPCs. + template + static void BidirectionalStreamingInvoker(const CallContext& context, + const Packet&) + PW_UNLOCK_FUNCTION(rpc_lock()) { + internal::BasePwpbServerReader reader_writer( + context, MethodType::kBidirectionalStreaming); + rpc_lock().unlock(); + static_cast(context.method()) + .function_.stream_request(context.service(), reader_writer); + } + + // Stores the user-defined RPC in a generic wrapper. + Function function_; + + // Serde used to encode and decode pw_protobuf structs. + const PwpbMethodSerde& serde_; +}; + +// MethodTraits specialization for a static synchronous unary method. +// TODO(pwbug/658): Further qualify this (and nanopb) definition so that they +// can co-exist in the same project. +template +struct MethodTraits*> { + using Implementation = PwpbMethod; + using Request = Req; + using Response = Res; + + static constexpr MethodType kType = MethodType::kUnary; + static constexpr bool kSynchronous = true; + + static constexpr bool kServerStreaming = false; + static constexpr bool kClientStreaming = false; +}; + +// MethodTraits specialization for a synchronous raw unary method. +template +struct MethodTraits(T::*)> + : MethodTraits*> { + using Service = T; +}; + +// MethodTraits specialization for a static asynchronous unary method. +template +struct MethodTraits*> + : MethodTraits*> { + static constexpr bool kSynchronous = false; +}; + +// MethodTraits specialization for an asynchronous unary method. +template +struct MethodTraits(T::*)> + : MethodTraits(T::*)> { + static constexpr bool kSynchronous = false; +}; + +// MethodTraits specialization for a static server streaming method. +template +struct MethodTraits*> { + using Implementation = PwpbMethod; + using Request = Req; + using Response = Resp; + + static constexpr MethodType kType = MethodType::kServerStreaming; + static constexpr bool kServerStreaming = true; + static constexpr bool kClientStreaming = false; +}; + +// MethodTraits specialization for a server streaming method. +template +struct MethodTraits(T::*)> + : MethodTraits*> { + using Service = T; +}; + +// MethodTraits specialization for a static server streaming method. +template +struct MethodTraits*> { + using Implementation = PwpbMethod; + using Request = Req; + using Response = Resp; + + static constexpr MethodType kType = MethodType::kClientStreaming; + static constexpr bool kServerStreaming = false; + static constexpr bool kClientStreaming = true; +}; + +// MethodTraits specialization for a server streaming method. +template +struct MethodTraits(T::*)> + : MethodTraits*> { + using Service = T; +}; + +// MethodTraits specialization for a static server streaming method. +template +struct MethodTraits*> { + using Implementation = PwpbMethod; + using Request = Req; + using Response = Resp; + + static constexpr MethodType kType = MethodType::kBidirectionalStreaming; + static constexpr bool kServerStreaming = true; + static constexpr bool kClientStreaming = true; +}; + +// MethodTraits specialization for a server streaming method. +template +struct MethodTraits(T::*)> + : MethodTraits*> { + using Service = T; +}; + +} // namespace pw::rpc::internal diff --git a/pw_rpc/pwpb/public/pw_rpc/pwpb/internal/method_union.h b/pw_rpc/pwpb/public/pw_rpc/pwpb/internal/method_union.h new file mode 100644 index 0000000000..f09da1989c --- /dev/null +++ b/pw_rpc/pwpb/public/pw_rpc/pwpb/internal/method_union.h @@ -0,0 +1,57 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. +#pragma once + +#include "pw_rpc/internal/method_union.h" +#include "pw_rpc/pwpb/internal/common.h" +#include "pw_rpc/pwpb/internal/method.h" +#include "pw_rpc/raw/internal/method_union.h" + +namespace pw::rpc::internal { + +// MethodUnion which holds a pw_protobuf method or a raw method. +class PwpbMethodUnion : public MethodUnion { + public: + constexpr PwpbMethodUnion(RawMethod&& method) + : impl_({.raw = std::move(method)}) {} + constexpr PwpbMethodUnion(PwpbMethod&& method) + : impl_({.pwpb = std::move(method)}) {} + + constexpr const Method& method() const { return impl_.method; } + constexpr const RawMethod& raw_method() const { return impl_.raw; } + constexpr const PwpbMethod& pwpb_method() const { return impl_.pwpb; } + + private: + union { + Method method; + RawMethod raw; + PwpbMethod pwpb; + } impl_; +}; + +// Deduces the type of an implemented service method from its signature, and +// returns the appropriate MethodUnion object to invoke it. +template +constexpr auto GetPwpbOrRawMethodFor(uint32_t id, + const PwpbMethodSerde& serde) { + if constexpr (RawMethod::matches()) { + return GetMethodFor(id); + } else if constexpr (PwpbMethod::matches()) { + return GetMethodFor(id, serde); + } else { + return InvalidMethod(id); + } +}; + +} // namespace pw::rpc::internal diff --git a/pw_rpc/pwpb/public/pw_rpc/pwpb/server_reader_writer.h b/pw_rpc/pwpb/public/pw_rpc/pwpb/server_reader_writer.h new file mode 100644 index 0000000000..47b48024c0 --- /dev/null +++ b/pw_rpc/pwpb/public/pw_rpc/pwpb/server_reader_writer.h @@ -0,0 +1,454 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +// This file defines the ServerReaderWriter, ServerReader, ServerWriter, and +// UnaryResponder classes for the pw_protobuf RPC interface. These classes are +// used for bidirectional, client, and server streaming, and unary RPCs. +#pragma once + +#include "pw_bytes/span.h" +#include "pw_function/function.h" +#include "pw_rpc/channel.h" +#include "pw_rpc/internal/lock.h" +#include "pw_rpc/internal/method_info.h" +#include "pw_rpc/internal/method_lookup.h" +#include "pw_rpc/internal/server_call.h" +#include "pw_rpc/method_type.h" +#include "pw_rpc/pwpb/internal/common.h" +#include "pw_rpc/server.h" + +namespace pw::rpc { +namespace internal { + +// Forward declarations for internal classes needed in friend statements. +namespace test { +template +class InvocationContext; +} // namespace test + +class PwpbMethod; + +// internal::PwpbServerCall extends internal::ServerCall by adding a method +// serializer/deserializer that is initialized based on the method context. +class PwpbServerCall : public internal::ServerCall { + public: + // Allow construction using a call context and method type which creates + // a working server call. + PwpbServerCall(const CallContext& context, MethodType type); + + // Sends a unary response. + // Returns the following Status codes: + // + // OK - the response was successfully sent + // FAILED_PRECONDITION - the writer is closed + // INTERNAL - pw_rpc was unable to encode the pw_protobuf protobuf + // other errors - the ChannelOutput failed to send the packet; the error + // codes are determined by the ChannelOutput implementation + // + template + Status SendUnaryResponse(const Response& response, Status status = OkStatus()) + PW_LOCKS_EXCLUDED(rpc_lock()) { + LockGuard lock(rpc_lock()); + if (!active_locked()) { + return Status::FailedPrecondition(); + } + + return PwpbSendFinalResponse(*this, response, status, serde().response()); + } + + // Give access to the serializer/deserializer object for converting requests + // and responses between the wire format and pw_protobuf structs. + const PwpbMethodSerde& serde() const { return *serde_; } + + protected: + // Derived classes allow default construction so that users can declare a + // variable into which to move server reader/writers from RPC calls. + constexpr PwpbServerCall() : serde_(nullptr) {} + + // Allow derived classes to be constructed moving another instance. + PwpbServerCall(PwpbServerCall&& other) PW_LOCKS_EXCLUDED(rpc_lock()) { + *this = std::move(other); + } + + // Allow derived classes to use move assignment from another instance. + PwpbServerCall& operator=(PwpbServerCall&& other) + PW_LOCKS_EXCLUDED(rpc_lock()) { + LockGuard lock(rpc_lock()); + MovePwpbServerCallFrom(other); + return *this; + } + + // Implement moving by copying the serde pointer. + void MovePwpbServerCallFrom(PwpbServerCall& other) + PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) { + MoveServerCallFrom(other); + serde_ = other.serde_; + } + + // Sends a streamed response. + // Returns the following Status codes: + // + // OK - the response was successfully sent + // FAILED_PRECONDITION - the writer is closed + // INTERNAL - pw_rpc was unable to encode the pw_protobuf protobuf + // other errors - the ChannelOutput failed to send the packet; the error + // codes are determined by the ChannelOutput implementation + // + template + Status SendStreamResponse(const Response& response) + PW_LOCKS_EXCLUDED(rpc_lock()) { + LockGuard lock(rpc_lock()); + if (!active_locked()) { + return Status::FailedPrecondition(); + } + + return PwpbSendStream(*this, response, serde().response()); + } + + private: + const PwpbMethodSerde* serde_; +}; + +// internal::BasePwpbServerReader extends internal::PwpbServerCall further by +// adding an on_next callback templated on the request type. +template +class BasePwpbServerReader : public PwpbServerCall { + public: + BasePwpbServerReader(const CallContext& context, MethodType type) + : PwpbServerCall(context, type) {} + + protected: + // Allow default construction so that users can declare a variable into + // which to move server reader/writers from RPC calls. + constexpr BasePwpbServerReader() = default; + + // Allow derived classes to be constructed moving another instance. + BasePwpbServerReader(BasePwpbServerReader&& other) + PW_LOCKS_EXCLUDED(rpc_lock()) { + *this = std::move(other); + } + + // Allow derived classes to use move assignment from another instance. + BasePwpbServerReader& operator=(BasePwpbServerReader&& other) + PW_LOCKS_EXCLUDED(rpc_lock()) { + LockGuard lock(rpc_lock()); + MoveBasePwpbServerReaderFrom(other); + return *this; + } + + // Implement moving by copying the on_next function. + void MoveBasePwpbServerReaderFrom(BasePwpbServerReader& other) + PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) { + MovePwpbServerCallFrom(other); + set_on_next_locked(std::move(other.pwpb_on_next_)); + } + + void set_on_next(Function&& on_next) + PW_LOCKS_EXCLUDED(rpc_lock()) { + LockGuard lock(rpc_lock()); + set_on_next_locked(std::move(on_next)); + } + + private: + void set_on_next_locked(Function&& on_next) + PW_EXCLUSIVE_LOCKS_REQUIRED(rpc_lock()) { + pwpb_on_next_ = std::move(on_next); + + Call::set_on_next_locked([this](ConstByteSpan payload) { + if (pwpb_on_next_) { + Request request{}; + const Status status = serde().DecodeRequest(payload, request); + if (status.ok()) { + pwpb_on_next_(request); + } + } + }); + } + + Function pwpb_on_next_; +}; + +} // namespace internal + +// The PwpbServerReaderWriter is used to send and receive typed messages in a +// pw_protobuf bidirectional streaming RPC. +// +// These classes use private inheritance to hide the internal::Call API while +// allow direct use of its public and protected functions. +template +class PwpbServerReaderWriter : private internal::BasePwpbServerReader { + public: + // Creates a PwpbServerReaderWriter that is ready to send responses for a + // particular RPC. This can be used for testing or to send responses to an RPC + // that has not been started by a client. + template + [[nodiscard]] static PwpbServerReaderWriter Open(Server& server, + uint32_t channel_id, + ServiceImpl& service) { + using MethodInfo = internal::MethodInfo; + static_assert(std::is_same_v, + "The request type of a PwpbServerReaderWriter must match " + "the method."); + static_assert(std::is_same_v, + "The response type of a PwpbServerReaderWriter must match " + "the method."); + internal::LockGuard lock(internal::rpc_lock()); + return {server.OpenContext( + channel_id, + service, + internal::MethodLookup::GetPwpbMethod())}; + } + + // Allow default construction so that users can declare a variable into + // which to move server reader/writers from RPC calls. + constexpr PwpbServerReaderWriter() = default; + + PwpbServerReaderWriter(PwpbServerReaderWriter&&) = default; + PwpbServerReaderWriter& operator=(PwpbServerReaderWriter&&) = default; + + using internal::Call::active; + using internal::Call::channel_id; + + // Functions for setting RPC event callbacks. + using internal::Call::set_on_error; + using internal::BasePwpbServerReader::set_on_next; + using internal::ServerCall::set_on_client_stream_end; + + // Writes a response. Returns the following Status codes: + // + // OK - the response was successfully sent + // FAILED_PRECONDITION - the writer is closed + // INTERNAL - pw_rpc was unable to encode the pw_protobuf message + // other errors - the ChannelOutput failed to send the packet; the error + // codes are determined by the ChannelOutput implementation + // + Status Write(const Response& response) { + return internal::PwpbServerCall::SendStreamResponse(response); + } + + Status Finish(Status status = OkStatus()) { + return internal::Call::CloseAndSendResponse(status); + } + + private: + template + friend class internal::test::InvocationContext; + + friend class internal::PwpbMethod; + + PwpbServerReaderWriter(const internal::CallContext& context, + MethodType type = MethodType::kBidirectionalStreaming) + : internal::BasePwpbServerReader(context, type) {} +}; + +// The PwpbServerReader is used to receive typed messages and send a typed +// response in a pw_protobuf client streaming RPC. +// +// These classes use private inheritance to hide the internal::Call API while +// allow direct use of its public and protected functions. +template +class PwpbServerReader : private internal::BasePwpbServerReader { + public: + // Creates a PwpbServerReader that is ready to send a response to a particular + // RPC. This can be used for testing or to finish an RPC that has not been + // started by the client. + template + [[nodiscard]] static PwpbServerReader Open(Server& server, + uint32_t channel_id, + ServiceImpl& service) { + using MethodInfo = internal::MethodInfo; + static_assert(std::is_same_v, + "The request type of a PwpbServerReader must match " + "the method."); + static_assert(std::is_same_v, + "The response type of a PwpbServerReader must match " + "the method."); + internal::LockGuard lock(internal::rpc_lock()); + return {server.OpenContext( + channel_id, + service, + internal::MethodLookup::GetPwpbMethod())}; + } + + // Allow default construction so that users can declare a variable into + // which to move server reader/writers from RPC calls. + constexpr PwpbServerReader() = default; + + PwpbServerReader(PwpbServerReader&&) = default; + PwpbServerReader& operator=(PwpbServerReader&&) = default; + + using internal::Call::active; + using internal::Call::channel_id; + + // Functions for setting RPC event callbacks. + using internal::Call::set_on_error; + using internal::BasePwpbServerReader::set_on_next; + using internal::ServerCall::set_on_client_stream_end; + + // Sends the response. Returns the following Status codes: + // + // OK - the response was successfully sent + // FAILED_PRECONDITION - the writer is closed + // INTERNAL - pw_rpc was unable to encode the pw_protobuf message + // other errors - the ChannelOutput failed to send the packet; the error + // codes are determined by the ChannelOutput implementation + // + Status Finish(const Response& response, Status status = OkStatus()) { + return internal::PwpbServerCall::SendUnaryResponse(response, status); + } + + private: + template + friend class internal::test::InvocationContext; + + friend class internal::PwpbMethod; + + PwpbServerReader(const internal::CallContext& context) + : internal::BasePwpbServerReader(context, + MethodType::kClientStreaming) {} +}; + +// The PwpbServerWriter is used to send typed responses in a pw_protobuf server +// streaming RPC. +// +// These classes use private inheritance to hide the internal::Call API while +// allow direct use of its public and protected functions. +template +class PwpbServerWriter : private internal::PwpbServerCall { + public: + // Creates a PwpbServerWriter that is ready to send responses for a particular + // RPC. This can be used for testing or to send responses to an RPC that has + // not been started by a client. + template + [[nodiscard]] static PwpbServerWriter Open(Server& server, + uint32_t channel_id, + ServiceImpl& service) { + using MethodInfo = internal::MethodInfo; + static_assert(std::is_same_v, + "The response type of a PwpbServerWriter must match " + "the method."); + internal::LockGuard lock(internal::rpc_lock()); + return {server.OpenContext( + channel_id, + service, + internal::MethodLookup::GetPwpbMethod())}; + } + + // Allow default construction so that users can declare a variable into + // which to move server reader/writers from RPC calls. + constexpr PwpbServerWriter() = default; + + PwpbServerWriter(PwpbServerWriter&&) = default; + PwpbServerWriter& operator=(PwpbServerWriter&&) = default; + + using internal::Call::active; + using internal::Call::channel_id; + + // Functions for setting RPC event callbacks. + using internal::Call::set_on_error; + using internal::ServerCall::set_on_client_stream_end; + + // Writes a response. Returns the following Status codes: + // + // OK - the response was successfully sent + // FAILED_PRECONDITION - the writer is closed + // INTERNAL - pw_rpc was unable to encode the pw_protobuf message + // other errors - the ChannelOutput failed to send the packet; the error + // codes are determined by the ChannelOutput implementation + // + Status Write(const Response& response) { + return internal::PwpbServerCall::SendStreamResponse(response); + } + + Status Finish(Status status = OkStatus()) { + return internal::Call::CloseAndSendResponse(status); + } + + private: + template + friend class internal::test::InvocationContext; + + friend class internal::PwpbMethod; + + PwpbServerWriter(const internal::CallContext& context) + : internal::PwpbServerCall(context, MethodType::kServerStreaming) {} +}; + +// The PwpbUnaryResponder is used to send a typed response in a pw_protobuf +// unary RPC. +// +// These classes use private inheritance to hide the internal::Call API while +// allow direct use of its public and protected functions. +template +class PwpbUnaryResponder : private internal::PwpbServerCall { + public: + // Creates a PwpbUnaryResponder that is ready to send responses for a + // particular RPC. This can be used for testing or to send responses to an + // RPC that has not been started by a client. + template + [[nodiscard]] static PwpbUnaryResponder Open(Server& server, + uint32_t channel_id, + ServiceImpl& service) { + using MethodInfo = internal::MethodInfo; + static_assert(std::is_same_v, + "The response type of a PwpbUnaryResponder must match " + "the method."); + internal::LockGuard lock(internal::rpc_lock()); + return {server.OpenContext( + channel_id, + service, + internal::MethodLookup::GetPwpbMethod())}; + } + + // Allow default construction so that users can declare a variable into + // which to move server reader/writers from RPC calls. + constexpr PwpbUnaryResponder() = default; + + PwpbUnaryResponder(PwpbUnaryResponder&&) = default; + PwpbUnaryResponder& operator=(PwpbUnaryResponder&&) = default; + + using internal::ServerCall::active; + using internal::ServerCall::channel_id; + + // Functions for setting RPC event callbacks. + using internal::Call::set_on_error; + using internal::ServerCall::set_on_client_stream_end; + + // Sends the response. Returns the following Status codes: + // + // OK - the response was successfully sent + // FAILED_PRECONDITION - the writer is closed + // INTERNAL - pw_rpc was unable to encode the pw_protobuf message + // other errors - the ChannelOutput failed to send the packet; the error + // codes are determined by the ChannelOutput implementation + // + Status Finish(const Response& response, Status status = OkStatus()) { + return internal::PwpbServerCall::SendUnaryResponse(response, status); + } + + private: + template + friend class internal::test::InvocationContext; + + friend class internal::PwpbMethod; + + PwpbUnaryResponder(const internal::CallContext& context) + : internal::PwpbServerCall(context, MethodType::kUnary) {} +}; + +} // namespace pw::rpc diff --git a/pw_rpc/pwpb/public/pw_rpc/pwpb/test_method_context.h b/pw_rpc/pwpb/public/pw_rpc/pwpb/test_method_context.h new file mode 100644 index 0000000000..2cce00af02 --- /dev/null +++ b/pw_rpc/pwpb/public/pw_rpc/pwpb/test_method_context.h @@ -0,0 +1,392 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. +#pragma once + +#include +#include +#include +#include +#include + +#include "pw_preprocessor/arguments.h" +#include "pw_rpc/internal/hash.h" +#include "pw_rpc/internal/method_lookup.h" +#include "pw_rpc/internal/test_method_context.h" +#include "pw_rpc/pwpb/fake_channel_output.h" +#include "pw_rpc/pwpb/internal/method.h" +#include "pw_rpc/pwpb/server_reader_writer.h" + +namespace pw::rpc { + +// Declares a context object that may be used to invoke an RPC. The context is +// declared with the name of the implemented service and the method to invoke. +// The RPC can then be invoked with the call method. +// +// For a unary RPC, context.call(request) returns the status, and the response +// struct can be accessed via context.response(). +// +// PW_PWPB_TEST_METHOD_CONTEXT(my::CoolService, TheMethod) context; +// EXPECT_EQ(OkStatus(), context.call({.some_arg = 123}).status()); +// EXPECT_EQ(500, context.response().some_response_value); +// +// For a unary RPC with repeated fields in the response, pw_protobuf uses a +// callback field called when parsing the response as many times as the +// field is present in the protobuf. To set the callback create the Response +// struct and pass it to the response method: +// +// PW_PWPB_TEST_METHOD_CONTEXT(my::CoolService, TheMethod) context; +// EXPECT_EQ(OkStatus(), context.call({.some_arg = 123}).status()); +// +// TheMethodResponse::Message response{}; +// response.repeated_field.SetDecoder([](TheMethod::StreamDecoder& decoder) { +// PW_TRY_ASSIGN(const auto value, decoder.ReadValue()); +// EXPECT_EQ(value, 123); +// return OkStatus(); +// }); +// context.response(response); // Callbacks called from here. +// +// For a server streaming RPC, context.call(request) invokes the method. As in a +// normal RPC, the method completes when the ServerWriter's Finish method is +// called (or it goes out of scope). +// +// PW_PWPB_TEST_METHOD_CONTEXT(my::CoolService, TheStreamingMethod) context; +// context.call({.some_arg = 123}); +// +// EXPECT_TRUE(context.done()); // Check that the RPC completed +// EXPECT_EQ(OkStatus(), context.status()); // Check the status +// +// EXPECT_EQ(3u, context.responses().size()); +// EXPECT_EQ(123, context.responses()[0].value); // check individual responses +// +// for (const MyResponse& response : context.responses()) { +// // iterate over the responses +// } +// +// PW_PWPB_TEST_METHOD_CONTEXT forwards its constructor arguments to the +// underlying service. For example: +// +// PW_PWPB_TEST_METHOD_CONTEXT(MyService, Go) context(service, args); +// +// PW_PWPB_TEST_METHOD_CONTEXT takes one optional argument: +// +// size_t kMaxPackets: maximum packets to store +// +// Example: +// +// PW_PWPB_TEST_METHOD_CONTEXT(MyService, BestMethod, 3, 256) context; +// ASSERT_EQ(3u, context.responses().max_size()); +// +#define PW_PWPB_TEST_METHOD_CONTEXT(service, method, ...) \ + ::pw::rpc::PwpbTestMethodContext + +template +class PwpbTestMethodContext; + +namespace internal::test::pwpb { + +// Collects everything needed to invoke a particular RPC. +template +class PwpbInvocationContext + : public InvocationContext< + PwpbFakeChannelOutput, + Service, + kMethodId> { + private: + using Base = InvocationContext< + PwpbFakeChannelOutput, + Service, + kMethodId>; + + public: + using Request = internal::Request; + using Response = internal::Response; + + // Gives access to the RPC's most recent response. + Response response() const { + Response response{}; + PW_ASSERT(kMethodInfo.serde() + .DecodeResponse(Base::responses().back(), response) + .ok()); + return response; + } + + // Gives access to the RPC's most recent response using passed Response object + // to parse using pw_protobuf. Use this version when you need to set callback + // fields in the Response object before parsing. + void response(Response& response) const { + PW_ASSERT(kMethodInfo.serde() + .DecodeResponse(Base::responses().back(), response) + .ok()); + } + + PwpbPayloadsView responses() const { + return Base::output().template payload_structs( + kMethodInfo.serde().response(), + MethodTraits::kType, + Base::channel_id(), + Base::service().id(), + kMethodId); + } + + protected: + template + PwpbInvocationContext(Args&&... args) + : Base(kMethodInfo, + MethodTraits::kType, + std::forward(args)...) {} + + template + void SendClientStream(const Request& request) PW_LOCKS_EXCLUDED(rpc_lock()) { + std::array buffer; + Base::SendClientStream(std::span(buffer).first( + kMethodInfo.serde().EncodeRequest(request, buffer).size())); + } + + private: + static constexpr PwpbMethod kMethodInfo = + MethodLookup::GetPwpbMethod(); +}; + +// Method invocation context for a unary RPC. Returns the status in +// call_context() and provides the response through the response() method. +template +class UnaryContext : public PwpbInvocationContext { + private: + using Base = PwpbInvocationContext; + + public: + using Request = typename Base::Request; + using Response = typename Base::Response; + + template + UnaryContext(Args&&... args) : Base(std::forward(args)...) {} + + // Invokes the RPC with the provided request. Returns the status. + auto call(const Request& request) { + if constexpr (MethodTraits::kSynchronous) { + Base::output().clear(); + + PwpbUnaryResponder responder = + Base::template GetResponder>(); + Response response = {}; + Status status = + CallMethodImplFunction(Base::service(), request, response); + PW_ASSERT(responder.Finish(response, status).ok()); + return status; + + } else { + Base::template call>(request); + } + } +}; + +// Method invocation context for a server streaming RPC. +template +class ServerStreamingContext + : public PwpbInvocationContext { + private: + using Base = PwpbInvocationContext; + + public: + using Request = typename Base::Request; + using Response = typename Base::Response; + + template + ServerStreamingContext(Args&&... args) : Base(std::forward(args)...) {} + + // Invokes the RPC with the provided request. + void call(const Request& request) { + Base::template call>(request); + } + + // Returns a server writer which writes responses into the context's buffer. + // This should not be called alongside call(); use one or the other. + PwpbServerWriter writer() { + return Base::template GetResponder>(); + } +}; + +// Method invocation context for a client streaming RPC. +template +class ClientStreamingContext + : public PwpbInvocationContext { + private: + using Base = PwpbInvocationContext; + + public: + using Request = typename Base::Request; + using Response = typename Base::Response; + + template + ClientStreamingContext(Args&&... args) : Base(std::forward(args)...) {} + + // Invokes the RPC. + void call() { + Base::template call>(); + } + + // Returns a server reader which writes responses into the context's buffer. + // This should not be called alongside call(); use one or the other. + PwpbServerReader reader() { + return Base::template GetResponder>(); + } + + // Allow sending client streaming packets. + using Base::SendClientStream; + using Base::SendClientStreamEnd; +}; + +// Method invocation context for a bidirectional streaming RPC. +template +class BidirectionalStreamingContext + : public PwpbInvocationContext { + private: + using Base = PwpbInvocationContext; + + public: + using Request = typename Base::Request; + using Response = typename Base::Response; + + template + BidirectionalStreamingContext(Args&&... args) + : Base(std::forward(args)...) {} + + // Invokes the RPC. + void call() { + Base::template call>(); + } + + // Returns a server reader which writes responses into the context's buffer. + // This should not be called alongside call(); use one or the other. + PwpbServerReaderWriter reader_writer() { + return Base::template GetResponder< + PwpbServerReaderWriter>(); + } + + // Allow sending client streaming packets. + using Base::SendClientStream; + using Base::SendClientStreamEnd; +}; + +// Alias to select the type of the context object to use based on which type of +// RPC it is for. +template +using Context = std::tuple_element_t< + static_cast(internal::MethodTraits::kType), + std::tuple< + UnaryContext, + ServerStreamingContext, + ClientStreamingContext, + BidirectionalStreamingContext>>; + +} // namespace internal::test::pwpb + +template +class PwpbTestMethodContext + : public internal::test::pwpb::Context { + public: + // Forwards constructor arguments to the service class. + template + PwpbTestMethodContext(ServiceArgs&&... service_args) + : internal::test::pwpb::Context( + std::forward(service_args)...) {} +}; + +} // namespace pw::rpc diff --git a/pw_rpc/pwpb/pw_rpc_pwpb_private/internal_test_utils.h b/pw_rpc/pwpb/pw_rpc_pwpb_private/internal_test_utils.h new file mode 100644 index 0000000000..6f95f9d503 --- /dev/null +++ b/pw_rpc/pwpb/pw_rpc_pwpb_private/internal_test_utils.h @@ -0,0 +1,69 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. +#pragma once + +#include +#include +#include + +#include "pw_status/status.h" +#include "pw_stream/memory_stream.h" + +namespace pw::rpc::internal { + +// Encodes a protobuf to a local span named by result from a list of pw_protobuf +// struct initializers. Note that the proto namespace is passed, not the name +// of the struct --- ie. exclude the "::Message" suffix. +// +// PW_ENCODE_PB(pw::rpc::TestProto, encoded, .value = 42); +// +#define PW_ENCODE_PB(proto, result, ...) \ + _PW_ENCODE_PB_EXPAND(proto, result, __LINE__, __VA_ARGS__) + +#define _PW_ENCODE_PB_EXPAND(proto, result, unique, ...) \ + _PW_ENCODE_PB_IMPL(proto, result, unique, __VA_ARGS__) + +#define _PW_ENCODE_PB_IMPL(proto, result, unique, ...) \ + std::array _pb_buffer_##unique{}; \ + const std::span result = \ + ::pw::rpc::internal::EncodeProtobuf( \ + proto::Message{__VA_ARGS__}, _pb_buffer_##unique) + +template +std::span EncodeProtobuf(const Message& message, + std::span buffer) { + MemoryEncoder encoder(buffer); + EXPECT_EQ(encoder.Write(message), OkStatus()); + return buffer.first(encoder.size()); +} + +// Decodes a protobuf to a pw_protobuf struct named by result. Note that the +// proto namespace is passed, not the name of the struct --- ie. exclude the +// "::Message" suffix. +// +// PW_DECODE_PB(pw::rpc::TestProto, decoded, buffer); +// +#define PW_DECODE_PB(proto, result, buffer) \ + proto::Message result; \ + ::pw::rpc::internal::DecodeProtobuf( \ + buffer, result); + +template +void DecodeProtobuf(std::span buffer, Message& message) { + stream::MemoryReader reader(buffer); + EXPECT_EQ(StreamDecoder(reader).Read(message), OkStatus()); +} + +} // namespace pw::rpc::internal diff --git a/pw_rpc/pwpb/serde_test.cc b/pw_rpc/pwpb/serde_test.cc new file mode 100644 index 0000000000..b9ae4aa9fe --- /dev/null +++ b/pw_rpc/pwpb/serde_test.cc @@ -0,0 +1,58 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include + +#include "gtest/gtest.h" +#include "pw_rpc/pwpb/internal/common.h" +#include "pw_rpc_test_protos/test.pwpb.h" +#include "pw_status/status_with_size.h" + +namespace pw::rpc::internal { +namespace { + +constexpr PwpbSerde kTestRequest(&test::TestRequest::kMessageFields); +constexpr test::TestRequest::Message kProto{.integer = 3, .status_code = 0}; + +TEST(PwpbSerde, Encode) { + std::byte buffer[32] = {}; + + StatusWithSize result = kTestRequest.Encode(kProto, buffer); + EXPECT_EQ(OkStatus(), result.status()); + EXPECT_EQ(result.size(), 4u); + EXPECT_EQ(buffer[0], std::byte{1} << 3); + EXPECT_EQ(buffer[1], std::byte{3}); + // pw_protobuf encodes zero fields + EXPECT_EQ(buffer[2], std::byte{2} << 3); + EXPECT_EQ(buffer[3], std::byte{0}); +} + +TEST(PwpbSerde, Encode_TooSmall) { + std::byte buffer[1] = {}; + EXPECT_EQ(Status::ResourceExhausted(), + kTestRequest.Encode(kProto, buffer).status()); +} + +TEST(PwpbSerde, Decode) { + constexpr std::byte buffer[]{std::byte{1} << 3, std::byte{3}}; + test::TestRequest::Message proto = {}; + + EXPECT_EQ(OkStatus(), kTestRequest.Decode(buffer, proto)); + + EXPECT_EQ(3, proto.integer); + EXPECT_EQ(0u, proto.status_code); +} + +} // namespace +} // namespace pw::rpc::internal diff --git a/pw_rpc/pwpb/server_callback_test.cc b/pw_rpc/pwpb/server_callback_test.cc new file mode 100644 index 0000000000..7e5e27fe35 --- /dev/null +++ b/pw_rpc/pwpb/server_callback_test.cc @@ -0,0 +1,91 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include + +#include "gtest/gtest.h" +#include "pw_containers/vector.h" +#include "pw_rpc/pwpb/test_method_context.h" +#include "pw_rpc/service.h" +#include "pw_rpc_test_protos/test.rpc.pwpb.h" + +namespace pw::rpc { + +class TestServiceImpl final + : public test::pw_rpc::pwpb::TestService::Service { + public: + Status TestUnaryRpc(const test::TestRequest::Message&, + test::TestResponse::Message& response) { + response.value = 42; + return OkStatus(); + } + + Status TestAnotherUnaryRpc(const test::TestRequest::Message&, + test::TestResponse::Message& response) { + response.value = 42; + response.repeated_field.SetEncoder( + [](test::TestResponse::StreamEncoder& encoder) { + constexpr std::array kValues = {7, 8, 9}; + return encoder.WriteRepeatedField(kValues); + }); + return OkStatus(); + } + + void TestServerStreamRpc( + const test::TestRequest::Message&, + PwpbServerWriter&) {} + + void TestClientStreamRpc( + PwpbServerReader&) {} + + void TestBidirectionalStreamRpc( + PwpbServerReaderWriter&) {} +}; + +TEST(PwpbTestMethodContext, ResponseWithoutCallbacks) { + // Calling response() without an argument returns a Response struct without + // any callbacks set. + PW_PWPB_TEST_METHOD_CONTEXT(TestServiceImpl, TestUnaryRpc) ctx; + ASSERT_EQ(ctx.call({}), OkStatus()); + + test::TestResponse::Message response = ctx.response(); + EXPECT_EQ(42, response.value); +} + +TEST(PwpbTestMethodContext, ResponseWithCallbacks) { + PW_PWPB_TEST_METHOD_CONTEXT(TestServiceImpl, TestAnotherUnaryRpc) ctx; + ASSERT_EQ(ctx.call({}), OkStatus()); + + // To decode a response object that requires to set callbacks, pass it to the + // response() method as a parameter. + pw::Vector values{}; + + test::TestResponse::Message response{}; + response.repeated_field.SetDecoder( + [&values](test::TestResponse::StreamDecoder& decoder) { + return decoder.ReadRepeatedField(values); + }); + ctx.response(response); + + EXPECT_EQ(42, response.value); + + EXPECT_EQ(3u, values.size()); + EXPECT_EQ(7u, values[0]); + EXPECT_EQ(8u, values[1]); + EXPECT_EQ(9u, values[2]); +} + +} // namespace pw::rpc diff --git a/pw_rpc/pwpb/server_reader_writer.cc b/pw_rpc/pwpb/server_reader_writer.cc new file mode 100644 index 0000000000..311a058390 --- /dev/null +++ b/pw_rpc/pwpb/server_reader_writer.cc @@ -0,0 +1,25 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include "pw_rpc/pwpb/server_reader_writer.h" + +#include "pw_rpc/pwpb/internal/method.h" + +namespace pw::rpc::internal { + +PwpbServerCall::PwpbServerCall(const CallContext& context, MethodType type) + : internal::ServerCall(context, type), + serde_(&static_cast(context.method()).serde()) {} + +} // namespace pw::rpc::internal diff --git a/pw_rpc/pwpb/server_reader_writer_test.cc b/pw_rpc/pwpb/server_reader_writer_test.cc new file mode 100644 index 0000000000..b16d6a746d --- /dev/null +++ b/pw_rpc/pwpb/server_reader_writer_test.cc @@ -0,0 +1,297 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include "pw_rpc/pwpb/server_reader_writer.h" + +#include "gtest/gtest.h" +#include "pw_rpc/pwpb/fake_channel_output.h" +#include "pw_rpc/pwpb/test_method_context.h" +#include "pw_rpc/service.h" +#include "pw_rpc_test_protos/test.rpc.pwpb.h" + +namespace pw::rpc { + +class TestServiceImpl final + : public test::pw_rpc::pwpb::TestService::Service { + public: + Status TestUnaryRpc(const test::TestRequest::Message&, + test::TestResponse::Message&) { + return OkStatus(); + } + + void TestAnotherUnaryRpc(const test::TestRequest::Message&, + PwpbUnaryResponder&) {} + + void TestServerStreamRpc( + const test::TestRequest::Message&, + PwpbServerWriter&) {} + + void TestClientStreamRpc( + PwpbServerReader&) {} + + void TestBidirectionalStreamRpc( + PwpbServerReaderWriter&) {} +}; + +template +struct ReaderWriterTestContext { + using Info = internal::MethodInfo; + + static constexpr uint32_t kChannelId = 1; + + ReaderWriterTestContext() + : channel(Channel::Create(&output)), + server(std::span(&channel, 1)) {} + + TestServiceImpl service; + PwpbFakeChannelOutput<4> output; + Channel channel; + Server server; +}; + +using test::pw_rpc::pwpb::TestService; + +TEST(PwpbUnaryResponder, DefaultConstructed) { + PwpbUnaryResponder call; + + ASSERT_FALSE(call.active()); + EXPECT_EQ(call.channel_id(), Channel::kUnassignedChannelId); + + EXPECT_EQ(Status::FailedPrecondition(), call.Finish({}, OkStatus())); + + call.set_on_error([](Status) {}); +} + +TEST(PwpbServerWriter, DefaultConstructed) { + PwpbServerWriter call; + + ASSERT_FALSE(call.active()); + EXPECT_EQ(call.channel_id(), Channel::kUnassignedChannelId); + + EXPECT_EQ(Status::FailedPrecondition(), call.Write({})); + EXPECT_EQ(Status::FailedPrecondition(), call.Finish(OkStatus())); + + call.set_on_error([](Status) {}); +} + +TEST(PwpbServerReader, DefaultConstructed) { + PwpbServerReader + call; + + ASSERT_FALSE(call.active()); + EXPECT_EQ(call.channel_id(), Channel::kUnassignedChannelId); + + EXPECT_EQ(Status::FailedPrecondition(), call.Finish({}, OkStatus())); + + call.set_on_next([](const test::TestRequest::Message&) {}); + call.set_on_error([](Status) {}); +} + +TEST(PwpbServerReaderWriter, DefaultConstructed) { + PwpbServerReaderWriter + call; + + ASSERT_FALSE(call.active()); + EXPECT_EQ(call.channel_id(), Channel::kUnassignedChannelId); + + EXPECT_EQ(Status::FailedPrecondition(), call.Write({})); + EXPECT_EQ(Status::FailedPrecondition(), call.Finish(OkStatus())); + + call.set_on_next([](const test::TestRequest::Message&) {}); + call.set_on_error([](Status) {}); +} + +TEST(PwpbUnaryResponder, Closed) { + ReaderWriterTestContext ctx; + PwpbUnaryResponder call = + PwpbUnaryResponder::Open< + TestService::TestUnaryRpc>(ctx.server, ctx.channel.id(), ctx.service); + ASSERT_EQ(OkStatus(), call.Finish({}, OkStatus())); + + ASSERT_FALSE(call.active()); + EXPECT_EQ(call.channel_id(), Channel::kUnassignedChannelId); + + EXPECT_EQ(Status::FailedPrecondition(), call.Finish({}, OkStatus())); + + call.set_on_error([](Status) {}); +} + +TEST(PwpbServerWriter, Closed) { + ReaderWriterTestContext ctx; + PwpbServerWriter call = + PwpbServerWriter::Open< + TestService::TestServerStreamRpc>( + ctx.server, ctx.channel.id(), ctx.service); + ASSERT_EQ(OkStatus(), call.Finish(OkStatus())); + + ASSERT_FALSE(call.active()); + EXPECT_EQ(call.channel_id(), Channel::kUnassignedChannelId); + + EXPECT_EQ(Status::FailedPrecondition(), call.Write({})); + EXPECT_EQ(Status::FailedPrecondition(), call.Finish(OkStatus())); + + call.set_on_error([](Status) {}); +} + +TEST(PwpbServerReader, Closed) { + ReaderWriterTestContext ctx; + PwpbServerReader call = PwpbServerReader:: + Open( + ctx.server, ctx.channel.id(), ctx.service); + ASSERT_EQ(OkStatus(), call.Finish({}, OkStatus())); + + ASSERT_FALSE(call.active()); + EXPECT_EQ(call.channel_id(), Channel::kUnassignedChannelId); + + EXPECT_EQ(Status::FailedPrecondition(), call.Finish({}, OkStatus())); + + call.set_on_next([](const test::TestRequest::Message&) {}); + call.set_on_error([](Status) {}); +} + +TEST(PwpbServerReaderWriter, Closed) { + ReaderWriterTestContext ctx; + PwpbServerReaderWriter call = + PwpbServerReaderWriter:: + Open( + ctx.server, ctx.channel.id(), ctx.service); + ASSERT_EQ(OkStatus(), call.Finish(OkStatus())); + + ASSERT_FALSE(call.active()); + EXPECT_EQ(call.channel_id(), Channel::kUnassignedChannelId); + + EXPECT_EQ(Status::FailedPrecondition(), call.Write({})); + EXPECT_EQ(Status::FailedPrecondition(), call.Finish(OkStatus())); + + call.set_on_next([](const test::TestRequest::Message&) {}); + call.set_on_error([](Status) {}); +} + +TEST(PwpbUnaryResponder, Open_ReturnsUsableResponder) { + ReaderWriterTestContext ctx; + PwpbUnaryResponder responder = + PwpbUnaryResponder::Open< + TestService::TestUnaryRpc>(ctx.server, ctx.channel.id(), ctx.service); + + ASSERT_EQ(OkStatus(), + responder.Finish({.value = 4321, .repeated_field = {}})); + + EXPECT_EQ(ctx.output.last_response().value, 4321); + EXPECT_EQ(ctx.output.last_status(), OkStatus()); +} + +TEST(PwpbServerWriter, Open_ReturnsUsableWriter) { + ReaderWriterTestContext ctx; + PwpbServerWriter responder = + PwpbServerWriter::Open< + TestService::TestServerStreamRpc>( + ctx.server, ctx.channel.id(), ctx.service); + + ASSERT_EQ(OkStatus(), responder.Write({.chunk = {}, .number = 321})); + ASSERT_EQ(OkStatus(), responder.Finish()); + + EXPECT_EQ(ctx.output.last_response().number, + 321u); + EXPECT_EQ(ctx.output.last_status(), OkStatus()); +} + +TEST(PwpbServerReader, Open_ReturnsUsableReader) { + ReaderWriterTestContext ctx; + PwpbServerReader responder = + PwpbServerReader:: + Open( + ctx.server, ctx.channel.id(), ctx.service); + + ASSERT_EQ(OkStatus(), responder.Finish({.chunk = {}, .number = 321})); + + EXPECT_EQ(ctx.output.last_response().number, + 321u); +} + +TEST(PwpbServerReaderWriter, Open_ReturnsUsableReaderWriter) { + ReaderWriterTestContext ctx; + PwpbServerReaderWriter responder = + PwpbServerReaderWriter:: + Open( + ctx.server, ctx.channel.id(), ctx.service); + + ASSERT_EQ(OkStatus(), responder.Write({.chunk = {}, .number = 321})); + ASSERT_EQ(OkStatus(), responder.Finish(Status::NotFound())); + + EXPECT_EQ(ctx.output.last_response() + .number, + 321u); + EXPECT_EQ(ctx.output.last_status(), Status::NotFound()); +} + +TEST(RawServerReaderWriter, Open_UnknownChannel) { + ReaderWriterTestContext ctx; + ASSERT_EQ(OkStatus(), ctx.server.CloseChannel(ctx.kChannelId)); + + PwpbServerReaderWriter call = + PwpbServerReaderWriter:: + Open( + ctx.server, ctx.kChannelId, ctx.service); + + EXPECT_TRUE(call.active()); + EXPECT_EQ(call.channel_id(), ctx.kChannelId); + EXPECT_EQ(Status::Unavailable(), call.Write({})); + + ASSERT_EQ(OkStatus(), ctx.server.OpenChannel(ctx.kChannelId, ctx.output)); + + EXPECT_EQ(OkStatus(), call.Write({})); + EXPECT_TRUE(call.active()); + + EXPECT_EQ(OkStatus(), call.Finish()); + EXPECT_FALSE(call.active()); + EXPECT_EQ(call.channel_id(), Channel::kUnassignedChannelId); +} + +TEST(PwpbServerReader, CallbacksMoveCorrectly) { + PW_PWPB_TEST_METHOD_CONTEXT(TestServiceImpl, TestClientStreamRpc) ctx; + + PwpbServerReader call_1 = ctx.reader(); + + ASSERT_TRUE(call_1.active()); + + test::TestRequest::Message received_request = {.integer = 12345678, + .status_code = 1}; + + call_1.set_on_next( + [&received_request](const test::TestRequest::Message& value) { + received_request = value; + }); + + PwpbServerReader + call_2; + call_2 = std::move(call_1); + + constexpr test::TestRequest::Message request{.integer = 600613, + .status_code = 2}; + ctx.SendClientStream(request); + EXPECT_EQ(request.integer, received_request.integer); + EXPECT_EQ(request.status_code, received_request.status_code); +} + +} // namespace pw::rpc diff --git a/pw_rpc/pwpb/stub_generation_test.cc b/pw_rpc/pwpb/stub_generation_test.cc new file mode 100644 index 0000000000..cf1919ff8a --- /dev/null +++ b/pw_rpc/pwpb/stub_generation_test.cc @@ -0,0 +1,29 @@ +// Copyright 2022 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +// This macro is used to remove the generated stubs from the proto files. Define +// so that the generated stubs can be tested. +#define _PW_RPC_COMPILE_GENERATED_SERVICE_STUBS + +#include "gtest/gtest.h" +#include "pw_rpc_test_protos/test.rpc.pwpb.h" + +namespace { + +TEST(PwpbServiceStub, GeneratedStubCompiles) { + ::pw::rpc::test::TestService test_service; + EXPECT_STREQ(test_service.name(), "TestService"); +} + +} // namespace diff --git a/pw_rpc/py/BUILD.bazel b/pw_rpc/py/BUILD.bazel index 4d9d454feb..bf5bdbef63 100644 --- a/pw_rpc/py/BUILD.bazel +++ b/pw_rpc/py/BUILD.bazel @@ -27,6 +27,7 @@ filegroup( "pw_rpc/callback_client/impl.py", "pw_rpc/codegen.py", "pw_rpc/codegen_nanopb.py", + "pw_rpc/codegen_pwpb.py", "pw_rpc/codegen_raw.py", "pw_rpc/console_tools/__init__.py", "pw_rpc/console_tools/console.py", @@ -37,6 +38,7 @@ filegroup( "pw_rpc/packets.py", "pw_rpc/plugin.py", "pw_rpc/plugin_nanopb.py", + "pw_rpc/plugin_pwpb.py", "pw_rpc/plugin_raw.py", ], ) @@ -69,6 +71,20 @@ py_binary( ], ) +py_binary( + name = "plugin_pwpb", + srcs = [":pw_rpc_common_sources"], + imports = ["."], + main = "pw_rpc/plugin_pwpb.py", + python_version = "PY3", + deps = [ + "//pw_protobuf/py:plugin_library", + "//pw_protobuf_compiler/py:pw_protobuf_compiler", + "//pw_status/py:pw_status", + "@com_google_protobuf//:protobuf_python", + ], +) + py_library( name = "pw_rpc", srcs = [ diff --git a/pw_rpc/py/BUILD.gn b/pw_rpc/py/BUILD.gn index 03c6ddf03d..3968fe281a 100644 --- a/pw_rpc/py/BUILD.gn +++ b/pw_rpc/py/BUILD.gn @@ -38,6 +38,7 @@ pw_python_package("py") { "pw_rpc/client.py", "pw_rpc/codegen.py", "pw_rpc/codegen_nanopb.py", + "pw_rpc/codegen_pwpb.py", "pw_rpc/codegen_raw.py", "pw_rpc/console_tools/__init__.py", "pw_rpc/console_tools/console.py", @@ -49,6 +50,7 @@ pw_python_package("py") { "pw_rpc/packets.py", "pw_rpc/plugin.py", "pw_rpc/plugin_nanopb.py", + "pw_rpc/plugin_pwpb.py", "pw_rpc/plugin_raw.py", "pw_rpc/testing.py", ] diff --git a/pw_rpc/py/pw_rpc/codegen_pwpb.py b/pw_rpc/py/pw_rpc/codegen_pwpb.py new file mode 100644 index 0000000000..df17090b06 --- /dev/null +++ b/pw_rpc/py/pw_rpc/codegen_pwpb.py @@ -0,0 +1,242 @@ +# Copyright 2022 The Pigweed Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +"""This module generates the code for pw_protobuf pw_rpc services.""" + +import os +from typing import Iterable + +from pw_protobuf.output_file import OutputFile +from pw_protobuf.proto_tree import ProtoServiceMethod +from pw_protobuf.proto_tree import build_node_tree +from pw_rpc import codegen +from pw_rpc.codegen import (client_call_type, get_id, CodeGenerator, + RPC_NAMESPACE) + +PROTO_H_EXTENSION = '.pwpb.h' +PWPB_H_EXTENSION = '.pwpb.h' + + +def _proto_filename_to_pwpb_header(proto_file: str) -> str: + """Returns the generated pwpb header name for a .proto file.""" + filename = os.path.splitext(proto_file)[0] + return f'{filename}{PWPB_H_EXTENSION}' + + +def _proto_filename_to_generated_header(proto_file: str) -> str: + """Returns the generated C++ RPC header name for a .proto file.""" + filename = os.path.splitext(proto_file)[0] + return f'{filename}.rpc{PROTO_H_EXTENSION}' + + +def _serde(method: ProtoServiceMethod) -> str: + """Returns the PwpbMethodSerde for this method.""" + return (f'{RPC_NAMESPACE}::internal::kPwpbMethodSerde<' + f'&{method.request_type().pwpb_table()}, ' + f'&{method.response_type().pwpb_table()}>') + + +def _client_call(method: ProtoServiceMethod) -> str: + template_args = [] + + if method.client_streaming(): + template_args.append(method.request_type().pwpb_struct()) + + template_args.append(method.response_type().pwpb_struct()) + + return f'{client_call_type(method, "Pwpb")}<{", ".join(template_args)}>' + + +def _function(method: ProtoServiceMethod) -> str: + return f'{_client_call(method)} {method.name()}' + + +def _user_args(method: ProtoServiceMethod) -> Iterable[str]: + if not method.client_streaming(): + yield f'const {method.request_type().pwpb_struct()}& request' + + response = method.response_type().pwpb_struct() + + if method.server_streaming(): + yield f'::pw::Function&& on_next = nullptr' + yield '::pw::Function&& on_completed = nullptr' + else: + yield (f'::pw::Function&& ' + 'on_completed = nullptr') + + yield '::pw::Function&& on_error = nullptr' + + +class PwpbCodeGenerator(CodeGenerator): + """Generates an RPC service and client using the pw_protobuf API.""" + def name(self) -> str: + return 'pwpb' + + def method_union_name(self) -> str: + return 'PwpbMethodUnion' + + def includes(self, proto_file_name: str) -> Iterable[str]: + yield '#include "pw_rpc/pwpb/client_reader_writer.h"' + yield '#include "pw_rpc/pwpb/internal/method_union.h"' + yield '#include "pw_rpc/pwpb/server_reader_writer.h"' + + # Include the corresponding pwpb header file for this proto file, in + # which the file's messages and enums are generated. All other files + # imported from the .proto file are #included in there. + pwpb_header = _proto_filename_to_pwpb_header(proto_file_name) + yield f'#include "{pwpb_header}"' + + def service_aliases(self) -> None: + self.line('template ') + self.line('using ServerWriter = ' + f'{RPC_NAMESPACE}::PwpbServerWriter;') + self.line('template ') + self.line('using ServerReader = ' + f'{RPC_NAMESPACE}::PwpbServerReader;') + self.line('template ') + self.line( + 'using ServerReaderWriter = ' + f'{RPC_NAMESPACE}::PwpbServerReaderWriter;') + + def method_descriptor(self, method: ProtoServiceMethod) -> None: + impl_method = f'&Implementation::{method.name()}' + + self.line( + f'{RPC_NAMESPACE}::internal::GetPwpbOrRawMethodFor<{impl_method}, ' + f'{method.type().cc_enum()}, ' + f'{method.request_type().pwpb_struct()}, ' + f'{method.response_type().pwpb_struct()}>(') + with self.indent(4): + self.line(f'{get_id(method)}, // Hash of "{method.name()}"') + self.line(f'{_serde(method)}),') + + def client_member_function(self, method: ProtoServiceMethod) -> None: + """Outputs client code for a single RPC method.""" + + self.line(f'{_function(method)}(') + self.indented_list(*_user_args(method), end=') const {') + + with self.indent(): + client_call = _client_call(method) + base = 'Stream' if method.server_streaming() else 'Unary' + self.line(f'return {RPC_NAMESPACE}::internal::' + f'Pwpb{base}ResponseClientCall<' + f'{method.response_type().pwpb_struct()}>::' + f'Start<{client_call}>(') + + service_client = RPC_NAMESPACE + '::internal::ServiceClient' + + args = [ + f'{service_client}::client()', + f'{service_client}::channel_id()', + 'kServiceId', + get_id(method), + _serde(method), + ] + if method.server_streaming(): + args.append('std::move(on_next)') + + args.append('std::move(on_completed)') + args.append('std::move(on_error)') + + if not method.client_streaming(): + args.append('request') + + self.indented_list(*args, end=');') + + self.line('}') + + def client_static_function(self, method: ProtoServiceMethod) -> None: + self.line(f'static {_function(method)}(') + self.indented_list(f'{RPC_NAMESPACE}::Client& client', + 'uint32_t channel_id', + *_user_args(method), + end=') {') + + with self.indent(): + self.line(f'return Client(client, channel_id).{method.name()}(') + + args = [] + + if not method.client_streaming(): + args.append('request') + + if method.server_streaming(): + args.append('std::move(on_next)') + + self.indented_list(*args, + 'std::move(on_completed)', + 'std::move(on_error)', + end=');') + + self.line('}') + + def method_info_specialization(self, method: ProtoServiceMethod) -> None: + self.line() + self.line(f'using Request = {method.request_type().pwpb_struct()};') + self.line(f'using Response = {method.response_type().pwpb_struct()};') + self.line() + self.line(f'static constexpr const {RPC_NAMESPACE}::internal::' + 'PwpbMethodSerde& serde() {') + with self.indent(): + self.line(f'return {_serde(method)};') + self.line('}') + + +class StubGenerator(codegen.StubGenerator): + """Generates pw_protobuf RPC stubs.""" + def unary_signature(self, method: ProtoServiceMethod, prefix: str) -> str: + return (f'::pw::Status {prefix}{method.name()}( ' + f'const {method.request_type().pwpb_struct()}& request, ' + f'{method.response_type().pwpb_struct()}& response)') + + def unary_stub(self, method: ProtoServiceMethod, + output: OutputFile) -> None: + output.write_line(codegen.STUB_REQUEST_TODO) + output.write_line('static_cast(request);') + output.write_line(codegen.STUB_RESPONSE_TODO) + output.write_line('static_cast(response);') + output.write_line('return ::pw::Status::Unimplemented();') + + def server_streaming_signature(self, method: ProtoServiceMethod, + prefix: str) -> str: + return ( + f'void {prefix}{method.name()}( ' + f'const {method.request_type().pwpb_struct()}& request, ' + f'ServerWriter<{method.response_type().pwpb_struct()}>& writer)') + + def client_streaming_signature(self, method: ProtoServiceMethod, + prefix: str) -> str: + return (f'void {prefix}{method.name()}( ' + f'ServerReader<{method.request_type().pwpb_struct()}, ' + f'{method.response_type().pwpb_struct()}>& reader)') + + def bidirectional_streaming_signature(self, method: ProtoServiceMethod, + prefix: str) -> str: + return (f'void {prefix}{method.name()}( ' + f'ServerReaderWriter<{method.request_type().pwpb_struct()}, ' + f'{method.response_type().pwpb_struct()}>& reader_writer)') + + +def process_proto_file(proto_file) -> Iterable[OutputFile]: + """Generates code for a single .proto file.""" + + _, package_root = build_node_tree(proto_file) + output_filename = _proto_filename_to_generated_header(proto_file.name) + + generator = PwpbCodeGenerator(output_filename) + codegen.generate_package(proto_file, package_root, generator) + + codegen.package_stubs(package_root, generator, StubGenerator()) + + return [generator.output] diff --git a/pw_rpc/py/pw_rpc/plugin.py b/pw_rpc/py/pw_rpc/plugin.py index 30d616045f..1042748712 100644 --- a/pw_rpc/py/pw_rpc/plugin.py +++ b/pw_rpc/py/pw_rpc/plugin.py @@ -19,12 +19,14 @@ from google.protobuf.compiler import plugin_pb2 from pw_rpc import codegen_nanopb +from pw_rpc import codegen_pwpb from pw_rpc import codegen_raw class Codegen(enum.Enum): RAW = 0 NANOPB = 1 + PWPB = 2 def process_proto_request(codegen: Codegen, @@ -44,6 +46,8 @@ def process_proto_request(codegen: Codegen, output_files = codegen_raw.process_proto_file(proto_file) elif codegen is Codegen.NANOPB: output_files = codegen_nanopb.process_proto_file(proto_file) + elif codegen is Codegen.PWPB: + output_files = codegen_pwpb.process_proto_file(proto_file) else: raise NotImplementedError(f'Unknown codegen type {codegen}') diff --git a/pw_rpc/py/pw_rpc/plugin_pwpb.py b/pw_rpc/py/pw_rpc/plugin_pwpb.py new file mode 100755 index 0000000000..a99163d668 --- /dev/null +++ b/pw_rpc/py/pw_rpc/plugin_pwpb.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# Copyright 2022 The Pigweed Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +"""pw_rpc pw_protobuf protoc plugin.""" + +import sys + +from pw_rpc import plugin + + +def main() -> int: + return plugin.main(plugin.Codegen.PWPB) + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/third_party/rules_proto_grpc/internal_proto.bzl b/third_party/rules_proto_grpc/internal_proto.bzl index cab19d4050..1331a0689a 100644 --- a/third_party/rules_proto_grpc/internal_proto.bzl +++ b/third_party/rules_proto_grpc/internal_proto.bzl @@ -96,6 +96,21 @@ pwpb_compile = _proto_compiler_rule( pwpb_compile_aspect, ) +pwpb_rpc_compile_aspect = _proto_compiler_aspect( + [ + Label("@pigweed//pw_rpc:pw_cc_plugin_pwpb_rpc"), + Label("@pigweed//pw_protobuf:pw_cc_plugin"), + ], + "pwpb_rpc_proto_compile_aspect", +) +pwpb_rpc_compile = _proto_compiler_rule( + [ + Label("@pigweed//pw_rpc:pw_cc_plugin_pwpb_rpc"), + Label("@pigweed//pw_protobuf:pw_cc_plugin"), + ], + pwpb_rpc_compile_aspect, +) + raw_rpc_compile_aspect = _proto_compiler_aspect( [Label("@pigweed//pw_rpc:pw_cc_plugin_raw")], "raw_rpc_proto_compile_aspect", @@ -155,6 +170,17 @@ PLUGIN_INFO = { "has_srcs": False, "additional_tags": [], }, + "pwpb_rpc": { + "compiler": pwpb_rpc_compile, + "deps": [ + "@pigweed//pw_protobuf", + "@pigweed//pw_rpc/pwpb:server_api", + "@pigweed//pw_rpc/pwpb:client_api", + "@pigweed//pw_rpc", + ], + "has_srcs": False, + "additional_tags": [], + }, "raw_rpc": { "compiler": raw_rpc_compile, "deps": [