From d1d74fe17c0cb5e25183b3b71bc642ac1aa38a5a Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Wed, 14 Sep 2022 15:09:31 -0700 Subject: [PATCH 01/30] Initial runtime support for middlewares --- .../aws-smithy-http-server-python/Cargo.toml | 4 + .../examples/pokemon_service.py | 19 + .../aws-smithy-http-server-python/src/lib.rs | 3 + .../src/middleware.rs | 501 ++++++++++++++++++ .../src/server.rs | 173 +++++- 5 files changed, 692 insertions(+), 8 deletions(-) create mode 100644 rust-runtime/aws-smithy-http-server-python/src/middleware.rs diff --git a/rust-runtime/aws-smithy-http-server-python/Cargo.toml b/rust-runtime/aws-smithy-http-server-python/Cargo.toml index 7106b30c4a..e05c100299 100644 --- a/rust-runtime/aws-smithy-http-server-python/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server-python/Cargo.toml @@ -18,9 +18,12 @@ aws-smithy-types = { path = "../aws-smithy-types" } aws-smithy-http = { path = "../aws-smithy-http" } bytes = "1.2" futures = "0.3" +futures-core = "0.3" +http = "0.2" hyper = { version = "0.14.20", features = ["server", "http1", "http2", "tcp", "stream"] } num_cpus = "1.13.1" parking_lot = "0.12.1" +pin-project-lite = "0.2" pyo3 = "0.16.5" pyo3-asyncio = { version = "0.16.0", features = ["tokio-runtime"] } signal-hook = { version = "0.3.14", features = ["extended-siginfo"] } @@ -34,6 +37,7 @@ tracing-subscriber = { version = "0.3.15", features = ["env-filter"] } [dev-dependencies] pretty_assertions = "1" +futures-util = "0.3" [package.metadata.docs.rs] all-features = true diff --git a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py index cfad7c1309..1a6584072e 100644 --- a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py +++ b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py @@ -11,8 +11,10 @@ from typing import List, Optional import aiohttp + from libpokemon_service_server_sdk import App from libpokemon_service_server_sdk.error import ResourceNotFoundException +from libpokemon_service_server_sdk.http import Request from libpokemon_service_server_sdk.input import ( EmptyOperationInput, GetPokemonSpeciesInput, GetServerStatisticsInput, HealthCheckOperationInput, StreamPokemonRadioOperationInput) @@ -109,6 +111,23 @@ def get_random_radio_stream(self) -> str: app = App() # Register the context. app.context(Context()) +# Register a middleware. + + +########################################################### +# Middleware +########################################################### +@app.middleware +def check_header(request: Request): + logging.info("Inside MID1") + logging.info(request) + + +@app.middleware +def check_header2(request: Request): + logging.info("Inside MID2") + logging.info(request) + raise ValueError("Lol") ########################################################### diff --git a/rust-runtime/aws-smithy-http-server-python/src/lib.rs b/rust-runtime/aws-smithy-http-server-python/src/lib.rs index 1f5c738194..b263a5d54b 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/lib.rs @@ -13,6 +13,7 @@ mod error; pub mod logging; +mod middleware; mod server; mod socket; pub mod types; @@ -22,6 +23,8 @@ pub use error::Error; #[doc(inline)] pub use logging::LogLevel; #[doc(inline)] +pub use middleware::{PyMiddlewareHandler, PyRequest, PyMiddlewareHandlers, PyMiddleware, middleware_wrapper}; +#[doc(inline)] pub use server::{PyApp, PyHandler}; #[doc(inline)] pub use socket::PySocket; diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware.rs new file mode 100644 index 0000000000..2b148278ca --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware.rs @@ -0,0 +1,501 @@ +//! Authorize requests using the [`Authorization`] header asynchronously. +//! +//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization +//! +//! # Example +//! +//! ``` +//! use tower_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest}; +//! use hyper::{Request, Response, Body, Error}; +//! use http::{StatusCode, header::AUTHORIZATION}; +//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; +//! use futures_util::future::BoxFuture; +//! +//! #[derive(Clone, Copy)] +//! struct MyAuth; +//! +//! impl AsyncAuthorizeRequest for MyAuth +//! where +//! B: Send + Sync + 'static, +//! { +//! type RequestBody = B; +//! type ResponseBody = Body; +//! type Future = BoxFuture<'static, Result, Response>>; +//! +//! fn authorize(&mut self, mut request: Request) -> Self::Future { +//! Box::pin(async { +//! if let Some(user_id) = check_auth(&request).await { +//! // Set `user_id` as a request extension so it can be accessed by other +//! // services down the stack. +//! request.extensions_mut().insert(user_id); +//! +//! Ok(request) +//! } else { +//! let unauthorized_response = Response::builder() +//! .status(StatusCode::UNAUTHORIZED) +//! .body(Body::empty()) +//! .unwrap(); +//! +//! Err(unauthorized_response) +//! } +//! }) +//! } +//! } +//! +//! async fn check_auth(request: &Request) -> Option { +//! // ... +//! # None +//! } +//! +//! #[derive(Debug)] +//! struct UserId(String); +//! +//! async fn handle(request: Request) -> Result, Error> { +//! // Access the `UserId` that was set in `on_authorized`. If `handle` gets called the +//! // request was authorized and `UserId` will be present. +//! let user_id = request +//! .extensions() +//! .get::() +//! .expect("UserId will be there if request was authorized"); +//! +//! println!("request from {:?}", user_id); +//! +//! Ok(Response::new(Body::empty())) +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box> { +//! let service = ServiceBuilder::new() +//! // Authorize requests using `MyAuth` +//! .layer(AsyncRequireAuthorizationLayer::new(MyAuth)) +//! .service_fn(handle); +//! # Ok(()) +//! # } +//! ``` +//! +//! Or using a closure: +//! +//! ``` +//! use tower_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest}; +//! use hyper::{Request, Response, Body, Error}; +//! use http::StatusCode; +//! use tower::{Service, ServiceExt, ServiceBuilder}; +//! use futures_util::future::BoxFuture; +//! +//! async fn check_auth(request: &Request) -> Option { +//! // ... +//! # None +//! } +//! +//! #[derive(Debug)] +//! struct UserId(String); +//! +//! async fn handle(request: Request) -> Result, Error> { +//! # todo!(); +//! // ... +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> Result<(), Box> { +//! let service = ServiceBuilder::new() +//! .layer(AsyncRequireAuthorizationLayer::new(|request: Request| async move { +//! if let Some(user_id) = check_auth(&request).await { +//! Ok(request) +//! } else { +//! let unauthorized_response = Response::builder() +//! .status(StatusCode::UNAUTHORIZED) +//! .body(Body::empty()) +//! .unwrap(); +//! +//! Err(unauthorized_response) +//! } +//! })) +//! .service_fn(handle); +//! # Ok(()) +//! # } +//! ``` + +use aws_smithy_http_server::body::{boxed, BoxBody}; +use bytes::Bytes; +use futures::future::BoxFuture; +use futures_core::ready; +use http::{Request, Response}; +use hyper::Body; +use pin_project_lite::pin_project; +use pyo3::{pyclass, IntoPy, PyAny, PyErr, PyObject, PyResult}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, error::Error, +}; +use tower::{Layer, Service}; + +#[pyclass(name = "Request")] +#[derive(Debug)] +pub struct PyRequest(Request); + +impl PyRequest { + pub fn new(request: &Request) -> Self { + let mut self_ = Request::builder() + .uri(request.uri()) + .method(request.method()) + .body(Bytes::new()) + .unwrap(); + let headers = self_.headers_mut(); + *headers = request.headers().clone(); + Self(self_) + } + + pub fn new_with_body(request: &Request) -> Self { + let mut self_ = Request::builder() + .uri(request.uri()) + .method(request.method()) + .body(Bytes::new()) + .unwrap(); + let headers = self_.headers_mut(); + *headers = request.headers().clone(); + Self(self_) + } +} + +impl Clone for PyRequest { + fn clone(&self) -> Self { + let mut request = Request::builder() + .uri(self.0.uri()) + .method(self.0.method()) + .body(self.0.body().clone()) + .unwrap(); + let headers = request.headers_mut(); + *headers = self.0.headers().clone(); + Self(request) + } +} + +/// A Python handler function representation. +/// +/// The Python business logic implementation needs to carry some information +/// to be executed properly like the size of its arguments and if it is +/// a coroutine. +#[derive(Debug, Clone)] +pub struct PyMiddlewareHandler { + pub name: String, + pub func: PyObject, + pub is_coroutine: bool, + pub with_body: bool, +} + +#[derive(Debug, Clone)] +pub struct PyMiddlewareHandlers(pub Vec); + +// Our request handler. This is where we would implement the application logic +// for responding to HTTP requests... +pub async fn middleware_wrapper( + request: PyRequest, + handler: PyMiddlewareHandler, +) -> PyResult<()> { + let result = if handler.is_coroutine { + tracing::debug!("Executing Python handler coroutine `stream_pokemon_radio_operation()`"); + let result = pyo3::Python::with_gil(|py| { + let pyhandler: &pyo3::types::PyFunction = handler.func.extract(py)?; + let coroutine = pyhandler.call1((request,))?; + pyo3_asyncio::tokio::into_future(coroutine) + })?; + result + .await + .map(|_| ()) + } else { + tracing::debug!("Executing Python handler function `stream_pokemon_radio_operation()`"); + tokio::task::block_in_place(move || { + pyo3::Python::with_gil(|py| { + let pyhandler: &pyo3::types::PyFunction = handler.func.extract(py)?; + pyhandler.call1((request,))?; + Ok(()) + }) + }) + }; + // Catch and record a Python traceback. + result.map_err(|e| { + let traceback = pyo3::Python::with_gil(|py| match e.traceback(py) { + Some(t) => t.format().unwrap_or_else(|e| e.to_string()), + None => "Unknown traceback\n".to_string(), + }); + tracing::error!("{}{}", traceback, e); + e + })?; + Ok(()) +} + +impl PyMiddleware for PyMiddlewareHandlers +where + B: Send + Sync + 'static, +{ + type RequestBody = B; + type ResponseBody = BoxBody; + type Future = BoxFuture<'static, Result, Response>>; + + fn run(&mut self, request: Request) -> Self::Future { + let handlers = self.0.clone(); + Box::pin(async move { + for handler in handlers { + let pyrequest = if handler.with_body { + PyRequest::new_with_body(&request) + } else { + PyRequest::new(&request) + }; + middleware_wrapper(pyrequest, handler).await.map_err(|e| into_response(e))?; + } + Ok(request) + }) + } +} + +fn into_response(error: T) -> Response { + Response::builder().status(500).body(boxed(error.to_string())).unwrap() + +} + +/// Layer that applies [`AsyncRequireAuthorization`] which authorizes all requests using the +/// [`Authorization`] header. +/// +/// See the [module docs](crate::auth::async_require_authorization) for an example. +/// +/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization +#[derive(Debug, Clone)] +pub struct PyMiddlewareLayer { + handler: T, +} + +impl PyMiddlewareLayer { + /// Authorize requests using a custom scheme. + pub fn new(handler: T) -> PyMiddlewareLayer { + Self { handler } + } +} + +impl Layer for PyMiddlewareLayer +where + T: Clone, +{ + type Service = PyMiddlewareService; + + fn layer(&self, inner: S) -> Self::Service { + PyMiddlewareService::new(inner, self.handler.clone()) + } +} + +/// Middleware that authorizes all requests using the [`Authorization`] header. +/// +/// See the [module docs](crate::auth::async_require_authorization) for an example. +/// +/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization +#[derive(Clone, Debug)] +pub struct PyMiddlewareService { + inner: S, + handler: T, +} + +impl PyMiddlewareService { + /// Authorize requests using a custom scheme. + /// + /// The `Authorization` header is required to have the value provided. + pub fn new(inner: S, handler: T) -> PyMiddlewareService { + Self { inner, handler } + } + + /// Returns a new [`Layer`] that wraps services with an [`AsyncRequireAuthorizationLayer`] + /// middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer(handler: T) -> PyMiddlewareLayer { + PyMiddlewareLayer::new(handler) + } +} + +impl Service> for PyMiddlewareService +where + M: PyMiddleware, + S: Service, Response = Response> + Clone, +{ + type Response = Response; + type Error = S::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + let inner = self.inner.clone(); + let run = self.handler.run(req); + + ResponseFuture { + middleware: State::Run { run }, + service: inner, + } + } +} + +pin_project! { + /// Response future for [`AsyncRequireAuthorization`]. + pub struct ResponseFuture + where + M: PyMiddleware, + S: Service>, + { + #[pin] + middleware: State, + service: S, + } +} + +pin_project! { + #[project = StateProj] + enum State { + Run { + #[pin] + run: A, + }, + Done { + #[pin] + fut: Fut + } + } +} + +impl Future for ResponseFuture +where + M: PyMiddleware, + S: Service, Response = Response>, +{ + type Output = Result, S::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + loop { + match this.middleware.as_mut().project() { + StateProj::Run { run } => { + let run = ready!(run.poll(cx)); + match run { + Ok(req) => { + let fut = this.service.call(req); + this.middleware.set(State::Done { fut }); + } + Err(res) => return Poll::Ready(Ok(res)), + } + } + StateProj::Done { fut } => return fut.poll(cx), + } + } + } +} + +/// Trait for authorizing requests. +pub trait PyMiddleware { + type RequestBody; + type ResponseBody; + /// The Future type returned by `authorize` + type Future: Future, Response>>; + + /// Authorize the request. + /// + /// If the future resolves to `Ok(request)` then the request is allowed through, otherwise not. + fn run(&mut self, request: Request) -> Self::Future; +} + +impl PyMiddleware for F +where + F: FnMut(Request) -> Fut, + Fut: Future, Response>>, +{ + type RequestBody = ReqBody; + type ResponseBody = ResBody; + type Future = Fut; + + fn run(&mut self, request: Request) -> Self::Future { + self(request) + } +} + +#[cfg(test)] +mod tests { + #[allow(unused_imports)] + use super::*; + use futures_util::future::BoxFuture; + use http::{header, StatusCode}; + use hyper::Body; + use tower::{BoxError, ServiceBuilder, ServiceExt}; + + #[derive(Clone, Copy)] + struct MyAuth; + + impl PyMiddleware for MyAuth + where + B: Send + 'static, + { + type RequestBody = B; + type ResponseBody = Body; + type Future = BoxFuture<'static, Result, Response>>; + + fn authorize(&mut self, mut request: Request) -> Self::Future { + Box::pin(async move { + let authorized = request + .headers() + .get(header::AUTHORIZATION) + .and_then(|it| it.to_str().ok()) + .and_then(|it| it.strip_prefix("Bearer ")) + .map(|it| it == "69420") + .unwrap_or(false); + + if authorized { + let user_id = UserId("6969".to_owned()); + request.extensions_mut().insert(user_id); + Ok(request) + } else { + Err(Response::builder() + .status(StatusCode::UNAUTHORIZED) + .body(Body::empty()) + .unwrap()) + } + }) + } + } + + #[derive(Debug)] + struct UserId(String); + + #[tokio::test] + async fn require_async_auth_works() { + let mut service = ServiceBuilder::new() + .layer(PyMiddlewareLayer::new(MyAuth)) + .service_fn(echo); + + let request = Request::get("/") + .header(header::AUTHORIZATION, "Bearer 69420") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::OK); + } + + #[tokio::test] + async fn require_async_auth_401() { + let mut service = ServiceBuilder::new() + .layer(PyMiddlewareLayer::new(MyAuth)) + .service_fn(echo); + + let request = Request::get("/") + .header(header::AUTHORIZATION, "Bearer deez") + .body(Body::empty()) + .unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); + } + + async fn echo(req: Request) -> Result, BoxError> { + Ok(Response::new(req.into_body())) + } +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/server.rs b/rust-runtime/aws-smithy-http-server-python/src/server.rs index 7eacb154cc..bcda8dd388 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/server.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/server.rs @@ -4,16 +4,30 @@ */ // Code generated by software.amazon.smithy.rust.codegen.smithy-rs. DO NOT EDIT. -use std::{collections::HashMap, ops::Deref, process, thread}; +use std::{ + collections::HashMap, + ops::Deref, + pin::Pin, + process, + task::{Context, Poll}, + thread, +}; use aws_smithy_http_server::{AddExtensionLayer, Router}; +use bytes::Bytes; +use futures::Future; +use http::{request::Parts, Request, Response}; use parking_lot::Mutex; -use pyo3::{prelude::*, types::IntoPyDict}; +use pin_project_lite::pin_project; +use pyo3::{exceptions::PyException, prelude::*, types::IntoPyDict}; use signal_hook::{consts::*, iterator::Signals}; use tokio::runtime; -use tower::ServiceBuilder; +use tower::{util::ServiceFn, Layer, Service, ServiceBuilder}; -use crate::PySocket; +use crate::{ + middleware::{PyMiddlewareHandler, PyMiddlewareLayer}, + PySocket, PyMiddlewareHandlers, +}; /// A Python handler function representation. /// @@ -61,6 +75,8 @@ pub trait PyApp: Clone + pyo3::IntoPy { /// Mapping between operation names and their `PyHandler` representation. fn handlers(&mut self) -> &mut HashMap; + fn middlewares(&mut self) -> &mut Vec; + /// Handle the graceful termination of Python workers by looping through all the /// active workers and calling `terminate()` on them. If termination fails, this /// method will try to `kill()` any failed worker. @@ -157,11 +173,11 @@ pub trait PyApp: Clone + pyo3::IntoPy { py.run( r#" import asyncio -import logging import functools import signal async def shutdown(sig, event_loop): + import logging logging.info(f"Caught signal {sig.name}, cancelling tasks registered on this loop") tasks = [task for task in asyncio.all_tasks() if task is not asyncio.current_task()] @@ -216,6 +232,7 @@ event_loop.add_signal_handler(signal.SIGINT, // Register signals on the Python event loop. self.register_python_signals(py, event_loop.to_object(py))?; + let middlewares = PyMiddlewareHandlers(self.middlewares().clone()); // Spawn a new background [std::thread] to run the application. tracing::debug!("Start the Tokio runtime in a background task"); thread::spawn(move || { @@ -229,8 +246,10 @@ event_loop.add_signal_handler(signal.SIGINT, // all inside a [tokio] blocking function. rt.block_on(async move { tracing::debug!("Add middlewares to Rust Python router"); - let app = - router.layer(ServiceBuilder::new().layer(AddExtensionLayer::new(context))); + let service = ServiceBuilder::new() + .layer(AddExtensionLayer::new(context)) + .layer(PyMiddlewareLayer::new(middlewares)); + let app = router.layer(service); let server = hyper::Server::from_tcp( raw_socket .try_into() @@ -252,6 +271,31 @@ event_loop.add_signal_handler(signal.SIGINT, Ok(()) } + fn register_middleware(&mut self, py: Python, func: PyObject, with_body: bool) -> PyResult<()> { + let inspect = py.import("inspect")?; + // Check if the function is a coroutine. + // NOTE: that `asyncio.iscoroutine()` doesn't work here. + let is_coroutine = inspect + .call_method1("iscoroutinefunction", (&func,))? + .extract::()?; + let name = func.getattr(py, "__name__")?.extract::(py)?; + // Find number of expected methods (a Pythzzon implementation could not accept the context). + let handler = PyMiddlewareHandler { + name, + func, + is_coroutine, + with_body, + }; + tracing::info!( + "Registering middleware function `{}`, coroutine: {}, with_body: {}", + handler.name, + handler.is_coroutine, + handler.with_body, + ); + self.middlewares().push(handler); + Ok(()) + } + /// Register a Python function to be executed inside the Smithy Rust handler. /// /// There are some information needed to execute the Python code from a Rust handler, @@ -276,7 +320,7 @@ event_loop.add_signal_handler(signal.SIGINT, args: func_args.len(), }; tracing::info!( - "Registering function `{name}`, coroutine: {}, arguments: {}", + "Registering handler function `{name}`, coroutine: {}, arguments: {}", handler.is_coroutine, handler.args, ); @@ -410,3 +454,116 @@ event_loop.add_signal_handler(signal.SIGINT, Ok(()) } } + +#[pyclass] +#[derive(Debug)] +struct PyRequest(Request); + +impl Clone for PyRequest { + fn clone(&self) -> Self { + let mut request = Request::builder() + .uri(self.0.uri()) + .method(self.0.method()) + .body(self.0.body().clone()) + .unwrap(); + let headers = request.headers_mut(); + *headers = self.0.headers().clone(); + Self(request) + } +} + +#[pyclass] +#[derive(Debug)] +struct PyResponse(Response); + +impl Clone for PyResponse { + fn clone(&self) -> Self { + let mut response = Response::builder() + .status(self.0.status()) + .body(self.0.body().clone()) + .unwrap(); + let headers = response.headers_mut(); + *headers = self.0.headers().clone(); + Self(response) + } +} + +#[derive(Debug, Clone)] +pub struct PyMiddleware { + handler: PyHandler, +} + +impl PyMiddleware { + /// Create a new [`TimeoutLayer`]. + pub fn new(handler: PyHandler) -> Self { + PyMiddleware { handler } + } +} + +impl Layer for PyMiddleware { + type Service = Middleware; + + fn layer(&self, inner: S) -> Self::Service { + Middleware::new(inner, self.handler.clone()) + } +} + +#[derive(Debug, Clone)] +pub struct Middleware { + inner: S, + handler: PyHandler, +} + +impl Middleware { + /// Create a new [`Timeout`]. + pub fn new(inner: S, handler: PyHandler) -> Self { + Self { inner, handler } + } + + /// Returns a new [`Layer`] that wraps services with a `Timeout` middleware. + /// + /// [`Layer`]: tower_layer::Layer + pub fn layer(handler: PyHandler) -> PyMiddleware { + PyMiddleware::new(handler) + } +} + +impl Service> for Middleware +where + S: Service, Response = Response>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = ResponseFuture; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + let response = self.inner.call(req); + ResponseFuture { inner: response } + } +} + +pin_project! { + /// Response future for [`Timeout`]. + pub struct ResponseFuture { + #[pin] + inner: F, + } +} + +impl Future for ResponseFuture +where + F: Future, E>>, + B: Default, +{ + type Output = Result, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.project(); + this.inner.poll(cx) + } +} From da9e05249652492b80765ec082380e5fd0bfcb8e Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Wed, 14 Sep 2022 15:12:58 -0700 Subject: [PATCH 02/30] Cleanup --- .../src/middleware.rs | 234 +----------------- .../src/server.rs | 138 +---------- 2 files changed, 16 insertions(+), 356 deletions(-) diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware.rs index 2b148278ca..beef505d66 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware.rs @@ -1,120 +1,3 @@ -//! Authorize requests using the [`Authorization`] header asynchronously. -//! -//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization -//! -//! # Example -//! -//! ``` -//! use tower_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest}; -//! use hyper::{Request, Response, Body, Error}; -//! use http::{StatusCode, header::AUTHORIZATION}; -//! use tower::{Service, ServiceExt, ServiceBuilder, service_fn}; -//! use futures_util::future::BoxFuture; -//! -//! #[derive(Clone, Copy)] -//! struct MyAuth; -//! -//! impl AsyncAuthorizeRequest for MyAuth -//! where -//! B: Send + Sync + 'static, -//! { -//! type RequestBody = B; -//! type ResponseBody = Body; -//! type Future = BoxFuture<'static, Result, Response>>; -//! -//! fn authorize(&mut self, mut request: Request) -> Self::Future { -//! Box::pin(async { -//! if let Some(user_id) = check_auth(&request).await { -//! // Set `user_id` as a request extension so it can be accessed by other -//! // services down the stack. -//! request.extensions_mut().insert(user_id); -//! -//! Ok(request) -//! } else { -//! let unauthorized_response = Response::builder() -//! .status(StatusCode::UNAUTHORIZED) -//! .body(Body::empty()) -//! .unwrap(); -//! -//! Err(unauthorized_response) -//! } -//! }) -//! } -//! } -//! -//! async fn check_auth(request: &Request) -> Option { -//! // ... -//! # None -//! } -//! -//! #[derive(Debug)] -//! struct UserId(String); -//! -//! async fn handle(request: Request) -> Result, Error> { -//! // Access the `UserId` that was set in `on_authorized`. If `handle` gets called the -//! // request was authorized and `UserId` will be present. -//! let user_id = request -//! .extensions() -//! .get::() -//! .expect("UserId will be there if request was authorized"); -//! -//! println!("request from {:?}", user_id); -//! -//! Ok(Response::new(Body::empty())) -//! } -//! -//! # #[tokio::main] -//! # async fn main() -> Result<(), Box> { -//! let service = ServiceBuilder::new() -//! // Authorize requests using `MyAuth` -//! .layer(AsyncRequireAuthorizationLayer::new(MyAuth)) -//! .service_fn(handle); -//! # Ok(()) -//! # } -//! ``` -//! -//! Or using a closure: -//! -//! ``` -//! use tower_http::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest}; -//! use hyper::{Request, Response, Body, Error}; -//! use http::StatusCode; -//! use tower::{Service, ServiceExt, ServiceBuilder}; -//! use futures_util::future::BoxFuture; -//! -//! async fn check_auth(request: &Request) -> Option { -//! // ... -//! # None -//! } -//! -//! #[derive(Debug)] -//! struct UserId(String); -//! -//! async fn handle(request: Request) -> Result, Error> { -//! # todo!(); -//! // ... -//! } -//! -//! # #[tokio::main] -//! # async fn main() -> Result<(), Box> { -//! let service = ServiceBuilder::new() -//! .layer(AsyncRequireAuthorizationLayer::new(|request: Request| async move { -//! if let Some(user_id) = check_auth(&request).await { -//! Ok(request) -//! } else { -//! let unauthorized_response = Response::builder() -//! .status(StatusCode::UNAUTHORIZED) -//! .body(Body::empty()) -//! .unwrap(); -//! -//! Err(unauthorized_response) -//! } -//! })) -//! .service_fn(handle); -//! # Ok(()) -//! # } -//! ``` - use aws_smithy_http_server::body::{boxed, BoxBody}; use bytes::Bytes; use futures::future::BoxFuture; @@ -124,9 +7,10 @@ use hyper::Body; use pin_project_lite::pin_project; use pyo3::{pyclass, IntoPy, PyAny, PyErr, PyObject, PyResult}; use std::{ + error::Error, future::Future, pin::Pin, - task::{Context, Poll}, error::Error, + task::{Context, Poll}, }; use tower::{Layer, Service}; @@ -171,11 +55,6 @@ impl Clone for PyRequest { } } -/// A Python handler function representation. -/// -/// The Python business logic implementation needs to carry some information -/// to be executed properly like the size of its arguments and if it is -/// a coroutine. #[derive(Debug, Clone)] pub struct PyMiddlewareHandler { pub name: String, @@ -189,10 +68,7 @@ pub struct PyMiddlewareHandlers(pub Vec); // Our request handler. This is where we would implement the application logic // for responding to HTTP requests... -pub async fn middleware_wrapper( - request: PyRequest, - handler: PyMiddlewareHandler, -) -> PyResult<()> { +pub async fn middleware_wrapper(request: PyRequest, handler: PyMiddlewareHandler) -> PyResult<()> { let result = if handler.is_coroutine { tracing::debug!("Executing Python handler coroutine `stream_pokemon_radio_operation()`"); let result = pyo3::Python::with_gil(|py| { @@ -200,9 +76,7 @@ pub async fn middleware_wrapper( let coroutine = pyhandler.call1((request,))?; pyo3_asyncio::tokio::into_future(coroutine) })?; - result - .await - .map(|_| ()) + result.await.map(|_| ()) } else { tracing::debug!("Executing Python handler function `stream_pokemon_radio_operation()`"); tokio::task::block_in_place(move || { @@ -242,7 +116,9 @@ where } else { PyRequest::new(&request) }; - middleware_wrapper(pyrequest, handler).await.map_err(|e| into_response(e))?; + middleware_wrapper(pyrequest, handler) + .await + .map_err(|e| into_response(e))?; } Ok(request) }) @@ -250,16 +126,12 @@ where } fn into_response(error: T) -> Response { - Response::builder().status(500).body(boxed(error.to_string())).unwrap() - + Response::builder() + .status(500) + .body(boxed(error.to_string())) + .unwrap() } -/// Layer that applies [`AsyncRequireAuthorization`] which authorizes all requests using the -/// [`Authorization`] header. -/// -/// See the [module docs](crate::auth::async_require_authorization) for an example. -/// -/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization #[derive(Debug, Clone)] pub struct PyMiddlewareLayer { handler: T, @@ -415,87 +287,3 @@ where self(request) } } - -#[cfg(test)] -mod tests { - #[allow(unused_imports)] - use super::*; - use futures_util::future::BoxFuture; - use http::{header, StatusCode}; - use hyper::Body; - use tower::{BoxError, ServiceBuilder, ServiceExt}; - - #[derive(Clone, Copy)] - struct MyAuth; - - impl PyMiddleware for MyAuth - where - B: Send + 'static, - { - type RequestBody = B; - type ResponseBody = Body; - type Future = BoxFuture<'static, Result, Response>>; - - fn authorize(&mut self, mut request: Request) -> Self::Future { - Box::pin(async move { - let authorized = request - .headers() - .get(header::AUTHORIZATION) - .and_then(|it| it.to_str().ok()) - .and_then(|it| it.strip_prefix("Bearer ")) - .map(|it| it == "69420") - .unwrap_or(false); - - if authorized { - let user_id = UserId("6969".to_owned()); - request.extensions_mut().insert(user_id); - Ok(request) - } else { - Err(Response::builder() - .status(StatusCode::UNAUTHORIZED) - .body(Body::empty()) - .unwrap()) - } - }) - } - } - - #[derive(Debug)] - struct UserId(String); - - #[tokio::test] - async fn require_async_auth_works() { - let mut service = ServiceBuilder::new() - .layer(PyMiddlewareLayer::new(MyAuth)) - .service_fn(echo); - - let request = Request::get("/") - .header(header::AUTHORIZATION, "Bearer 69420") - .body(Body::empty()) - .unwrap(); - - let res = service.ready().await.unwrap().call(request).await.unwrap(); - - assert_eq!(res.status(), StatusCode::OK); - } - - #[tokio::test] - async fn require_async_auth_401() { - let mut service = ServiceBuilder::new() - .layer(PyMiddlewareLayer::new(MyAuth)) - .service_fn(echo); - - let request = Request::get("/") - .header(header::AUTHORIZATION, "Bearer deez") - .body(Body::empty()) - .unwrap(); - - let res = service.ready().await.unwrap().call(request).await.unwrap(); - - assert_eq!(res.status(), StatusCode::UNAUTHORIZED); - } - - async fn echo(req: Request) -> Result, BoxError> { - Ok(Response::new(req.into_body())) - } -} diff --git a/rust-runtime/aws-smithy-http-server-python/src/server.rs b/rust-runtime/aws-smithy-http-server-python/src/server.rs index bcda8dd388..e944efe193 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/server.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/server.rs @@ -4,30 +4,16 @@ */ // Code generated by software.amazon.smithy.rust.codegen.smithy-rs. DO NOT EDIT. -use std::{ - collections::HashMap, - ops::Deref, - pin::Pin, - process, - task::{Context, Poll}, - thread, -}; +use std::{collections::HashMap, ops::Deref, process, thread}; use aws_smithy_http_server::{AddExtensionLayer, Router}; -use bytes::Bytes; -use futures::Future; -use http::{request::Parts, Request, Response}; use parking_lot::Mutex; -use pin_project_lite::pin_project; -use pyo3::{exceptions::PyException, prelude::*, types::IntoPyDict}; +use pyo3::{prelude::*, types::IntoPyDict}; use signal_hook::{consts::*, iterator::Signals}; use tokio::runtime; -use tower::{util::ServiceFn, Layer, Service, ServiceBuilder}; +use tower::ServiceBuilder; -use crate::{ - middleware::{PyMiddlewareHandler, PyMiddlewareLayer}, - PySocket, PyMiddlewareHandlers, -}; +use crate::{PyMiddlewareHandlers, PySocket, PyMiddlewareHandler}; /// A Python handler function representation. /// @@ -247,8 +233,7 @@ event_loop.add_signal_handler(signal.SIGINT, rt.block_on(async move { tracing::debug!("Add middlewares to Rust Python router"); let service = ServiceBuilder::new() - .layer(AddExtensionLayer::new(context)) - .layer(PyMiddlewareLayer::new(middlewares)); + .layer(AddExtensionLayer::new(context)); let app = router.layer(service); let server = hyper::Server::from_tcp( raw_socket @@ -454,116 +439,3 @@ event_loop.add_signal_handler(signal.SIGINT, Ok(()) } } - -#[pyclass] -#[derive(Debug)] -struct PyRequest(Request); - -impl Clone for PyRequest { - fn clone(&self) -> Self { - let mut request = Request::builder() - .uri(self.0.uri()) - .method(self.0.method()) - .body(self.0.body().clone()) - .unwrap(); - let headers = request.headers_mut(); - *headers = self.0.headers().clone(); - Self(request) - } -} - -#[pyclass] -#[derive(Debug)] -struct PyResponse(Response); - -impl Clone for PyResponse { - fn clone(&self) -> Self { - let mut response = Response::builder() - .status(self.0.status()) - .body(self.0.body().clone()) - .unwrap(); - let headers = response.headers_mut(); - *headers = self.0.headers().clone(); - Self(response) - } -} - -#[derive(Debug, Clone)] -pub struct PyMiddleware { - handler: PyHandler, -} - -impl PyMiddleware { - /// Create a new [`TimeoutLayer`]. - pub fn new(handler: PyHandler) -> Self { - PyMiddleware { handler } - } -} - -impl Layer for PyMiddleware { - type Service = Middleware; - - fn layer(&self, inner: S) -> Self::Service { - Middleware::new(inner, self.handler.clone()) - } -} - -#[derive(Debug, Clone)] -pub struct Middleware { - inner: S, - handler: PyHandler, -} - -impl Middleware { - /// Create a new [`Timeout`]. - pub fn new(inner: S, handler: PyHandler) -> Self { - Self { inner, handler } - } - - /// Returns a new [`Layer`] that wraps services with a `Timeout` middleware. - /// - /// [`Layer`]: tower_layer::Layer - pub fn layer(handler: PyHandler) -> PyMiddleware { - PyMiddleware::new(handler) - } -} - -impl Service> for Middleware -where - S: Service, Response = Response>, -{ - type Response = S::Response; - type Error = S::Error; - type Future = ResponseFuture; - - #[inline] - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx) - } - - fn call(&mut self, req: Request) -> Self::Future { - let response = self.inner.call(req); - ResponseFuture { inner: response } - } -} - -pin_project! { - /// Response future for [`Timeout`]. - pub struct ResponseFuture { - #[pin] - inner: F, - } -} - -impl Future for ResponseFuture -where - F: Future, E>>, - B: Default, -{ - type Output = Result, E>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - this.inner.poll(cx) - } -} From aeeaa89f620f8cab282bfd902a8983742f7032d6 Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Thu, 15 Sep 2022 19:06:40 -0700 Subject: [PATCH 03/30] Codegenerate middleware support --- .../PythonServerCodegenDecorator.kt | 1 + .../generators/PythonApplicationGenerator.kt | 95 +++++- .../generators/PythonServerModuleGenerator.kt | 17 ++ .../PythonServerOperationHandlerGenerator.kt | 18 +- .../AdditionalErrorsDecorator.kt | 51 +++- .../examples/pokemon_service.py | 25 +- .../aws-smithy-http-server-python/src/lib.rs | 5 +- .../src/middleware.rs | 289 ------------------ .../src/middleware/layer.rs | 143 +++++++++ .../src/middleware/mod.rs | 71 +++++ .../src/middleware/request.rs | 66 ++++ .../src/server.rs | 3 +- 12 files changed, 461 insertions(+), 323 deletions(-) delete mode 100644 rust-runtime/aws-smithy-http-server-python/src/middleware.rs create mode 100644 rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs create mode 100644 rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs create mode 100644 rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/customizations/PythonServerCodegenDecorator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/customizations/PythonServerCodegenDecorator.kt index 9ab914f01b..233b477022 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/customizations/PythonServerCodegenDecorator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/customizations/PythonServerCodegenDecorator.kt @@ -22,6 +22,7 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerRuntimeType import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerModuleGenerator import software.amazon.smithy.rust.codegen.server.smithy.customizations.AddInternalServerErrorToAllOperationsDecorator +import software.amazon.smithy.rust.codegen.server.smithy.customizations.AddMiddlewareErrorToAllOperationsDecorator /** * Configure the [lib] section of `Cargo.toml`. diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index a55cb4f23d..d02f339d77 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -13,6 +13,7 @@ import software.amazon.smithy.rust.codegen.client.rustlang.asType import software.amazon.smithy.rust.codegen.client.rustlang.rust import software.amazon.smithy.rust.codegen.client.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.client.rustlang.rustTemplate +import software.amazon.smithy.rust.codegen.client.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.client.smithy.CoreCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.Errors import software.amazon.smithy.rust.codegen.client.smithy.Inputs @@ -30,7 +31,6 @@ import software.amazon.smithy.rust.codegen.server.smithy.ServerCargoDependency * Example: * from pool import DatabasePool * from my_library import App, OperationInput, OperationOutput - * @dataclass * class Context: * db = DatabasePool() @@ -73,6 +73,7 @@ class PythonApplicationGenerator( arrayOf( "SmithyPython" to PythonServerCargoDependency.SmithyHttpServerPython(runtimeConfig).asType(), "SmithyServer" to ServerCargoDependency.SmithyHttpServer(runtimeConfig).asType(), + "http" to CargoDependency.Http.asType(), "pyo3" to PythonServerCargoDependency.PyO3.asType(), "pyo3_asyncio" to PythonServerCargoDependency.PyO3Asyncio.asType(), "tokio" to PythonServerCargoDependency.Tokio.asType(), @@ -92,6 +93,7 @@ class PythonApplicationGenerator( renderPyAppTrait(writer) renderAppImpl(writer) renderPyMethods(writer) + renderPyMiddleware(writer) } fun renderAppStruct(writer: RustWriter) { @@ -101,6 +103,7 @@ class PythonApplicationGenerator( ##[derive(Debug, Default)] pub struct App { handlers: #{HashMap}, + middlewares: Vec<#{SmithyPython}::PyMiddlewareHandler>, context: Option<#{pyo3}::PyObject>, workers: #{parking_lot}::Mutex>, } @@ -116,6 +119,7 @@ class PythonApplicationGenerator( fn clone(&self) -> Self { Self { handlers: self.handlers.clone(), + middlewares: self.middlewares.clone(), context: self.context.clone(), workers: #{parking_lot}::Mutex::new(vec![]), } @@ -151,7 +155,7 @@ class PythonApplicationGenerator( val name = operationName.toSnakeCase() rustTemplate( """ - let ${name}_locals = pyo3_asyncio::TaskLocals::new(event_loop); + let ${name}_locals = #{pyo3_asyncio}::TaskLocals::new(event_loop); let handler = self.handlers.get("$name").expect("Python handler for operation `$name` not found").clone(); let router = router.$name(move |input, state| { #{pyo3_asyncio}::tokio::scope(${name}_locals, crate::operation_handler::$name(input, state, handler)) @@ -162,11 +166,19 @@ class PythonApplicationGenerator( } rustTemplate( """ + let middleware_locals = pyo3_asyncio::TaskLocals::new(event_loop); + let middlewares = PyMiddlewareHandlers { + handlers: self.middlewares.clone(), + locals: middleware_locals, + }; + let service = #{tower}::ServiceBuilder::new().layer( + #{SmithyPython}::PyMiddlewareLayer::new(middlewares) + ); let router: #{SmithyServer}::Router = router .build() .expect("Unable to build operation registry") .into(); - Ok(router) + Ok(router.layer(service)) """, *codegenScope, ) @@ -181,14 +193,15 @@ class PythonApplicationGenerator( fn workers(&self) -> &#{parking_lot}::Mutex> { &self.workers } - fn context(&self) -> &Option<#{pyo3}::PyObject> { &self.context } - fn handlers(&mut self) -> &mut #{HashMap} { &mut self.handlers } + fn middlewares(&mut self) -> &mut Vec<#{SmithyPython}::PyMiddlewareHandler> { + &mut self.middlewares + } } """, *codegenScope, @@ -217,6 +230,18 @@ class PythonApplicationGenerator( pub fn context(&mut self, context: #{pyo3}::PyObject) { self.context = Some(context); } + /// Register a middleware function that will be run inside a Tower layer, without cloning the body. + ##[pyo3(text_signature = "(${'$'}self, func)")] + pub fn middleware(&mut self, py: pyo3::Python, func: pyo3::PyObject) -> pyo3::PyResult<()> { + use #{SmithyPython}::PyApp; + self.register_middleware(py, func, false) + } + /// Register a middleware function that will be run inside a Tower layer, cloning the body. + ##[pyo3(text_signature = "(${'$'}self, func)")] + pub fn middleware_with_body(&mut self, py: pyo3::Python, func: pyo3::PyObject) -> pyo3::PyResult<()> { + use #{SmithyPython}::PyApp; + self.register_middleware(py, func, true) + } /// Main entrypoint: start the server on multiple workers. ##[pyo3(text_signature = "(${'$'}self, address, port, backlog, workers)")] pub fn run( @@ -265,6 +290,66 @@ class PythonApplicationGenerator( } } + private fun renderPyMiddleware(writer: RustWriter) { + writer.rustTemplate(""" + ##[derive(Debug, Clone)] + struct PyMiddlewareHandlers { + handlers: Vec<#{SmithyPython}::PyMiddlewareHandler>, + locals: #{pyo3_asyncio}::TaskLocals + } + + impl #{SmithyPython}::PyMiddleware for PyMiddlewareHandlers + where + B: Send + Sync + 'static, + { + type RequestBody = B; + type ResponseBody = #{SmithyServer}::body::BoxBody; + type Future = futures_util::future::BoxFuture< + 'static, + Result<#{http}::Request, #{http}::Response>, + >; + + fn run(&mut self, request: http::Request) -> Self::Future { + let handlers = self.handlers.clone(); + let locals = self.locals.clone(); + Box::pin(async move { + // Run all Python handlers in a loop. + for handler in handlers { + let pyrequest = if handler.with_body { + #{SmithyPython}::PyRequest::new_with_body(&request).await + } else { + #{SmithyPython}::PyRequest::new(&request) + }; + let loop_locals = locals.clone(); + let result = #{pyo3_asyncio}::tokio::scope( + loop_locals, + #{SmithyPython}::py_middleware_wrapper(pyrequest, handler), + ); + if let Err(e) = result.await { + let error = crate::operation_ser::serialize_structure_crate_error_internal_server_error( + &e.into() + ).unwrap(); + let boxed_error = #{SmithyServer}::body::boxed(error); + return Err(#{http}::Response::builder() + .status(500) + .body(boxed_error) + .unwrap()); + } + } + Ok(request) + }) + } + } + impl std::convert::From for crate::error::InternalServerError { + fn from(variant: pyo3::PyErr) -> Self { + crate::error::InternalServerError { + message: variant.to_string(), + } + } + } + """, *codegenScope) + } + private fun renderPyApplicationRustDocs(writer: RustWriter) { writer.rust( """ diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt index cf9631047d..538a6e95c6 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt @@ -47,6 +47,7 @@ class PythonServerModuleGenerator( renderPyCodegeneratedTypes() renderPyWrapperTypes() renderPySocketType() + renderPyMiddlewareTypes() renderPyApplicationType() } } @@ -125,6 +126,22 @@ class PythonServerModuleGenerator( ) } + private fun RustWriter.renderPyMiddlewareTypes() { + rustTemplate( + """ + let middleware = #{pyo3}::types::PyModule::new(py, "middleware")?; + middleware.add_class::<#{SmithyPython}::PyRequest>()?; + pyo3::py_run!( + py, + middleware, + "import sys; sys.modules['libpokemon_service_server_sdk.middleware'] = middleware" + ); + m.add_submodule(middleware)?; + """, + *codegenScope + ) + } + // Render Python application type. private fun RustWriter.renderPyApplicationType() { rustTemplate( diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt index c624cb1385..a158d67722 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt @@ -90,16 +90,14 @@ class PythonServerOperationHandlerGenerator( rustTemplate( """ #{tracing}::debug!("Executing Python handler function `$name()`"); - #{tokio}::task::block_in_place(move || { - #{pyo3}::Python::with_gil(|py| { - let pyhandler: &#{pyo3}::types::PyFunction = handler.extract(py)?; - let output = if handler.args == 1 { - pyhandler.call1((input,))? - } else { - pyhandler.call1((input, state.0))? - }; - output.extract::<$output>() - }) + #{pyo3}::Python::with_gil(|py| { + let pyhandler: &#{pyo3}::types::PyFunction = handler.extract(py)?; + let output = if handler.args == 1 { + pyhandler.call1((input,))? + } else { + pyhandler.call1((input, state.0))? + }; + output.extract::<$output>() }) """, *codegenScope, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AdditionalErrorsDecorator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AdditionalErrorsDecorator.kt index b4033fe7a1..8473e817f0 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AdditionalErrorsDecorator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AdditionalErrorsDecorator.kt @@ -36,8 +36,10 @@ class AddInternalServerErrorToInfallibleOperationsDecorator : RustCodegenDecorat override val name: String = "AddInternalServerErrorToInfallibleOperations" override val order: Byte = 0 - override fun transformModel(service: ServiceShape, model: Model): Model = - addErrorShapeToModelOperations(service, model) { shape -> shape.allErrors(model).isEmpty() } + override fun transformModel(service: ServiceShape, model: Model): Model { + val errorShape = internalServerError(service.id.namespace) + return addErrorShapeToModelOperations(errorShape, model) { shape -> shape.allErrors(model).isEmpty() } + } override fun supportsCodegenContext(clazz: Class): Boolean = clazz.isAssignableFrom(ServerCodegenContext::class.java) @@ -65,19 +67,33 @@ class AddInternalServerErrorToAllOperationsDecorator : RustCodegenDecorator): Boolean = + clazz.isAssignableFrom(ServerCodegenContext::class.java) +} + +class AddMiddlewareErrorToAllOperationsDecorator : RustCodegenDecorator { + override val name: String = "AddMiddlewareErrorToAllOperations" + override val order: Byte = 0 + + override fun transformModel(service: ServiceShape, model: Model): Model { + val errorShape = middlwareError(service.id.namespace) + return addErrorShapeToModelOperations(errorShape, model) { true } + } override fun supportsCodegenContext(clazz: Class): Boolean = clazz.isAssignableFrom(ServerCodegenContext::class.java) } -fun addErrorShapeToModelOperations(service: ServiceShape, model: Model, opSelector: (OperationShape) -> Boolean): Model { - val errorShape = internalServerError(service.id.namespace) - val modelShapes = model.toBuilder().addShapes(listOf(errorShape)).build() +fun addErrorShapeToModelOperations(error: StructureShape, model: Model, opSelector: (OperationShape) -> Boolean): Model { + val modelShapes = model.toBuilder().addShapes(listOf(error)).build() return ModelTransformer.create().mapShapes(modelShapes) { shape -> if (shape is OperationShape && opSelector(shape)) { - shape.toBuilder().addError(errorShape).build() + shape.toBuilder().addError(error).build() } else { shape } @@ -94,3 +110,22 @@ private fun internalServerError(namespace: String): StructureShape = .addTrait(RequiredTrait()) .build(), ).build() + + +private fun middlwareError(namespace: String): StructureShape = + StructureShape.builder().id("$namespace#MiddlewareError") + .addTrait(ErrorTrait("server")) + .addMember( + MemberShape.builder() + .id("$namespace#MiddlewareError\$message") + .target("smithy.api#String") + .addTrait(RequiredTrait()) + .build(), + ) + .addMember( + MemberShape.builder() + .id("$namespace#MiddlewareError\$code") + .target("smithy.api#Integer") + .addTrait(RequiredTrait()) + .build(), + ).build() diff --git a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py index 1a6584072e..e943c1d766 100644 --- a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py +++ b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py @@ -14,10 +14,10 @@ from libpokemon_service_server_sdk import App from libpokemon_service_server_sdk.error import ResourceNotFoundException -from libpokemon_service_server_sdk.http import Request from libpokemon_service_server_sdk.input import ( EmptyOperationInput, GetPokemonSpeciesInput, GetServerStatisticsInput, HealthCheckOperationInput, StreamPokemonRadioOperationInput) +from libpokemon_service_server_sdk.middleware import Request from libpokemon_service_server_sdk.model import FlavorText, Language from libpokemon_service_server_sdk.output import ( EmptyOperationOutput, GetPokemonSpeciesOutput, GetServerStatisticsOutput, @@ -118,16 +118,25 @@ def get_random_radio_stream(self) -> str: # Middleware ########################################################### @app.middleware -def check_header(request: Request): - logging.info("Inside MID1") - logging.info(request) +def check_content_type_header(request: Request): + content_type = request.get_header("content-type") + if content_type == "application/json": + logging.debug("Found valid `application/json` content type") + else: + logging.error(f"Invalid content type: {content_type}") @app.middleware -def check_header2(request: Request): - logging.info("Inside MID2") - logging.info(request) - raise ValueError("Lol") +async def check_method_and_content_length(request: Request): + content_length = request.get_header("content-length") + logging.debug(f"Request method: {request.method()}") + if content_length is not None: + content_length = int(content_length) + logging.debug( + "Request content length: {content_length}" + ) + else: + logging.error(f"Invalid content length: {content_length}") ########################################################### diff --git a/rust-runtime/aws-smithy-http-server-python/src/lib.rs b/rust-runtime/aws-smithy-http-server-python/src/lib.rs index b263a5d54b..3c6c216c21 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/lib.rs @@ -23,7 +23,10 @@ pub use error::Error; #[doc(inline)] pub use logging::LogLevel; #[doc(inline)] -pub use middleware::{PyMiddlewareHandler, PyRequest, PyMiddlewareHandlers, PyMiddleware, middleware_wrapper}; +pub use middleware::{ + py_middleware_wrapper, PyMiddleware, PyMiddlewareException, PyMiddlewareHandler, + PyMiddlewareLayer, PyRequest, +}; #[doc(inline)] pub use server::{PyApp, PyHandler}; #[doc(inline)] diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware.rs deleted file mode 100644 index beef505d66..0000000000 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware.rs +++ /dev/null @@ -1,289 +0,0 @@ -use aws_smithy_http_server::body::{boxed, BoxBody}; -use bytes::Bytes; -use futures::future::BoxFuture; -use futures_core::ready; -use http::{Request, Response}; -use hyper::Body; -use pin_project_lite::pin_project; -use pyo3::{pyclass, IntoPy, PyAny, PyErr, PyObject, PyResult}; -use std::{ - error::Error, - future::Future, - pin::Pin, - task::{Context, Poll}, -}; -use tower::{Layer, Service}; - -#[pyclass(name = "Request")] -#[derive(Debug)] -pub struct PyRequest(Request); - -impl PyRequest { - pub fn new(request: &Request) -> Self { - let mut self_ = Request::builder() - .uri(request.uri()) - .method(request.method()) - .body(Bytes::new()) - .unwrap(); - let headers = self_.headers_mut(); - *headers = request.headers().clone(); - Self(self_) - } - - pub fn new_with_body(request: &Request) -> Self { - let mut self_ = Request::builder() - .uri(request.uri()) - .method(request.method()) - .body(Bytes::new()) - .unwrap(); - let headers = self_.headers_mut(); - *headers = request.headers().clone(); - Self(self_) - } -} - -impl Clone for PyRequest { - fn clone(&self) -> Self { - let mut request = Request::builder() - .uri(self.0.uri()) - .method(self.0.method()) - .body(self.0.body().clone()) - .unwrap(); - let headers = request.headers_mut(); - *headers = self.0.headers().clone(); - Self(request) - } -} - -#[derive(Debug, Clone)] -pub struct PyMiddlewareHandler { - pub name: String, - pub func: PyObject, - pub is_coroutine: bool, - pub with_body: bool, -} - -#[derive(Debug, Clone)] -pub struct PyMiddlewareHandlers(pub Vec); - -// Our request handler. This is where we would implement the application logic -// for responding to HTTP requests... -pub async fn middleware_wrapper(request: PyRequest, handler: PyMiddlewareHandler) -> PyResult<()> { - let result = if handler.is_coroutine { - tracing::debug!("Executing Python handler coroutine `stream_pokemon_radio_operation()`"); - let result = pyo3::Python::with_gil(|py| { - let pyhandler: &pyo3::types::PyFunction = handler.func.extract(py)?; - let coroutine = pyhandler.call1((request,))?; - pyo3_asyncio::tokio::into_future(coroutine) - })?; - result.await.map(|_| ()) - } else { - tracing::debug!("Executing Python handler function `stream_pokemon_radio_operation()`"); - tokio::task::block_in_place(move || { - pyo3::Python::with_gil(|py| { - let pyhandler: &pyo3::types::PyFunction = handler.func.extract(py)?; - pyhandler.call1((request,))?; - Ok(()) - }) - }) - }; - // Catch and record a Python traceback. - result.map_err(|e| { - let traceback = pyo3::Python::with_gil(|py| match e.traceback(py) { - Some(t) => t.format().unwrap_or_else(|e| e.to_string()), - None => "Unknown traceback\n".to_string(), - }); - tracing::error!("{}{}", traceback, e); - e - })?; - Ok(()) -} - -impl PyMiddleware for PyMiddlewareHandlers -where - B: Send + Sync + 'static, -{ - type RequestBody = B; - type ResponseBody = BoxBody; - type Future = BoxFuture<'static, Result, Response>>; - - fn run(&mut self, request: Request) -> Self::Future { - let handlers = self.0.clone(); - Box::pin(async move { - for handler in handlers { - let pyrequest = if handler.with_body { - PyRequest::new_with_body(&request) - } else { - PyRequest::new(&request) - }; - middleware_wrapper(pyrequest, handler) - .await - .map_err(|e| into_response(e))?; - } - Ok(request) - }) - } -} - -fn into_response(error: T) -> Response { - Response::builder() - .status(500) - .body(boxed(error.to_string())) - .unwrap() -} - -#[derive(Debug, Clone)] -pub struct PyMiddlewareLayer { - handler: T, -} - -impl PyMiddlewareLayer { - /// Authorize requests using a custom scheme. - pub fn new(handler: T) -> PyMiddlewareLayer { - Self { handler } - } -} - -impl Layer for PyMiddlewareLayer -where - T: Clone, -{ - type Service = PyMiddlewareService; - - fn layer(&self, inner: S) -> Self::Service { - PyMiddlewareService::new(inner, self.handler.clone()) - } -} - -/// Middleware that authorizes all requests using the [`Authorization`] header. -/// -/// See the [module docs](crate::auth::async_require_authorization) for an example. -/// -/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization -#[derive(Clone, Debug)] -pub struct PyMiddlewareService { - inner: S, - handler: T, -} - -impl PyMiddlewareService { - /// Authorize requests using a custom scheme. - /// - /// The `Authorization` header is required to have the value provided. - pub fn new(inner: S, handler: T) -> PyMiddlewareService { - Self { inner, handler } - } - - /// Returns a new [`Layer`] that wraps services with an [`AsyncRequireAuthorizationLayer`] - /// middleware. - /// - /// [`Layer`]: tower_layer::Layer - pub fn layer(handler: T) -> PyMiddlewareLayer { - PyMiddlewareLayer::new(handler) - } -} - -impl Service> for PyMiddlewareService -where - M: PyMiddleware, - S: Service, Response = Response> + Clone, -{ - type Response = Response; - type Error = S::Error; - type Future = ResponseFuture; - - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { - self.inner.poll_ready(cx) - } - - fn call(&mut self, req: Request) -> Self::Future { - let inner = self.inner.clone(); - let run = self.handler.run(req); - - ResponseFuture { - middleware: State::Run { run }, - service: inner, - } - } -} - -pin_project! { - /// Response future for [`AsyncRequireAuthorization`]. - pub struct ResponseFuture - where - M: PyMiddleware, - S: Service>, - { - #[pin] - middleware: State, - service: S, - } -} - -pin_project! { - #[project = StateProj] - enum State { - Run { - #[pin] - run: A, - }, - Done { - #[pin] - fut: Fut - } - } -} - -impl Future for ResponseFuture -where - M: PyMiddleware, - S: Service, Response = Response>, -{ - type Output = Result, S::Error>; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); - loop { - match this.middleware.as_mut().project() { - StateProj::Run { run } => { - let run = ready!(run.poll(cx)); - match run { - Ok(req) => { - let fut = this.service.call(req); - this.middleware.set(State::Done { fut }); - } - Err(res) => return Poll::Ready(Ok(res)), - } - } - StateProj::Done { fut } => return fut.poll(cx), - } - } - } -} - -/// Trait for authorizing requests. -pub trait PyMiddleware { - type RequestBody; - type ResponseBody; - /// The Future type returned by `authorize` - type Future: Future, Response>>; - - /// Authorize the request. - /// - /// If the future resolves to `Ok(request)` then the request is allowed through, otherwise not. - fn run(&mut self, request: Request) -> Self::Future; -} - -impl PyMiddleware for F -where - F: FnMut(Request) -> Fut, - Fut: Future, Response>>, -{ - type RequestBody = ReqBody; - type ResponseBody = ResBody; - type Future = Fut; - - fn run(&mut self, request: Request) -> Self::Future { - self(request) - } -} diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs new file mode 100644 index 0000000000..15af036ef4 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs @@ -0,0 +1,143 @@ +use std::{task::{Context, Poll}, pin::Pin}; + +use futures::{Future, ready}; +use http::{Request, Response}; +use pin_project_lite::pin_project; +use tower::{Layer, Service}; + +#[derive(Debug, Clone)] +pub struct PyMiddlewareLayer { + handler: T, +} + +impl PyMiddlewareLayer { + pub fn new(handler: T) -> PyMiddlewareLayer { + Self { handler } + } +} + +impl Layer for PyMiddlewareLayer +where + T: Clone, +{ + type Service = PyMiddlewareService; + + fn layer(&self, inner: S) -> Self::Service { + PyMiddlewareService::new(inner, self.handler.clone()) + } +} + +#[derive(Clone, Debug)] +pub struct PyMiddlewareService { + inner: S, + handler: T, +} + +impl PyMiddlewareService { + pub fn new(inner: S, handler: T) -> PyMiddlewareService { + Self { inner, handler } + } + + pub fn layer(handler: T) -> PyMiddlewareLayer { + PyMiddlewareLayer::new(handler) + } +} + +impl Service> for PyMiddlewareService +where + M: PyMiddleware, + S: Service, Response = Response> + Clone, +{ + type Response = Response; + type Error = S::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + let inner = self.inner.clone(); + let run = self.handler.run(req); + + ResponseFuture { + middleware: State::Run { run }, + service: inner, + } + } +} + +pin_project! { + pub struct ResponseFuture + where + M: PyMiddleware, + S: Service>, + { + #[pin] + middleware: State, + service: S, + } +} + +pin_project! { + #[project = StateProj] + enum State { + Run { + #[pin] + run: A, + }, + Done { + #[pin] + fut: Fut + } + } +} + +impl Future for ResponseFuture +where + M: PyMiddleware, + S: Service, Response = Response>, +{ + type Output = Result, S::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + loop { + match this.middleware.as_mut().project() { + StateProj::Run { run } => { + let run = ready!(run.poll(cx)); + match run { + Ok(req) => { + let fut = this.service.call(req); + this.middleware.set(State::Done { fut }); + } + Err(res) => return Poll::Ready(Ok(res)), + } + } + StateProj::Done { fut } => return fut.poll(cx), + } + } + } +} + +pub trait PyMiddleware { + type RequestBody; + type ResponseBody; + type Future: Future, Response>>; + + fn run(&mut self, request: Request) -> Self::Future; +} + +impl PyMiddleware for F +where + F: FnMut(Request) -> Fut, + Fut: Future, Response>>, +{ + type RequestBody = ReqBody; + type ResponseBody = ResBody; + type Future = Fut; + + fn run(&mut self, request: Request) -> Self::Future { + self(request) + } +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs new file mode 100644 index 0000000000..9938a5d897 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs @@ -0,0 +1,71 @@ +use pyo3::prelude::*; + +mod layer; +mod request; + +pub use self::layer::{PyMiddleware, PyMiddlewareLayer}; +pub use self::request::PyRequest; + +#[pyclass(name = "MiddlewareException", extends = pyo3::exceptions::PyException)] +#[derive(Debug, Clone)] +pub struct PyMiddlewareException { + #[pyo3(get, set)] + pub message: String, + #[pyo3(get, set)] + pub status_code: u16, +} + +#[pymethods] +impl PyMiddlewareException { + #[new] + fn newpy(message: String, status_code: u16) -> Self { + Self { + message, + status_code, + } + } +} + +#[derive(Debug, Clone)] +pub struct PyMiddlewareHandler { + pub name: String, + pub func: PyObject, + pub is_coroutine: bool, + pub with_body: bool, +} + +// Our request handler. This is where we would implement the application logic +// for responding to HTTP requests... +pub async fn py_middleware_wrapper( + request: PyRequest, + handler: PyMiddlewareHandler, +) -> PyResult<()> { + let result = if handler.is_coroutine { + tracing::debug!("Executing Python handler coroutine `stream_pokemon_radio_operation()`"); + let result = pyo3::Python::with_gil(|py| { + let pyhandler: &pyo3::types::PyFunction = handler.func.extract(py)?; + let coroutine = pyhandler.call1((request,))?; + pyo3_asyncio::tokio::into_future(coroutine) + })?; + result.await.map(|_| ()) + } else { + tracing::debug!("Executing Python handler function `stream_pokemon_radio_operation()`"); + tokio::task::block_in_place(move || { + pyo3::Python::with_gil(|py| { + let pyhandler: &pyo3::types::PyFunction = handler.func.extract(py)?; + pyhandler.call1((request,))?; + Ok(()) + }) + }) + }; + // Catch and record a Python traceback. + result.map_err(|e| { + let traceback = pyo3::Python::with_gil(|py| match e.traceback(py) { + Some(t) => t.format().unwrap_or_else(|e| e.to_string()), + None => "Unknown traceback\n".to_string(), + }); + tracing::error!("{}{}", traceback, e); + e + })?; + Ok(()) +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs new file mode 100644 index 0000000000..acce79504e --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs @@ -0,0 +1,66 @@ +use bytes::Bytes; +use http::Request; +use pyo3::prelude::*; + +#[pyclass(name = "Request")] +#[derive(Debug)] +pub struct PyRequest(Request); + +impl PyRequest { + pub fn new(request: &Request) -> Self { + let mut self_ = Request::builder() + .uri(request.uri()) + .method(request.method()) + .body(Bytes::new()) + .unwrap(); + let headers = self_.headers_mut(); + *headers = request.headers().clone(); + Self(self_) + } + + pub fn new_with_body(request: &Request) -> Self { + let mut self_ = Request::builder() + .uri(request.uri()) + .method(request.method()) + .body(Bytes::new()) + .unwrap(); + let headers = self_.headers_mut(); + *headers = request.headers().clone(); + Self(self_) + } +} + +impl Clone for PyRequest { + fn clone(&self) -> Self { + let mut request = Request::builder() + .uri(self.0.uri()) + .method(self.0.method()) + .body(self.0.body().clone()) + .unwrap(); + let headers = request.headers_mut(); + *headers = self.0.headers().clone(); + Self(request) + } +} + +#[pymethods] +impl PyRequest { + fn method(&self) -> String { + self.0.method().to_string() + } + + fn uri(&self) -> String { + self.0.uri().to_string() + } + + fn get_header(&self, name: &str) -> Option { + let value = self.0.headers().get(name); + match value { + Some(v) => match v.to_str() { + Ok(v) => Some(v.to_string()), + Err(_) => None, + }, + None => None, + } + } +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/server.rs b/rust-runtime/aws-smithy-http-server-python/src/server.rs index e944efe193..454f4b23f7 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/server.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/server.rs @@ -13,7 +13,7 @@ use signal_hook::{consts::*, iterator::Signals}; use tokio::runtime; use tower::ServiceBuilder; -use crate::{PyMiddlewareHandlers, PySocket, PyMiddlewareHandler}; +use crate::{PySocket, PyMiddlewareHandler}; /// A Python handler function representation. /// @@ -218,7 +218,6 @@ event_loop.add_signal_handler(signal.SIGINT, // Register signals on the Python event loop. self.register_python_signals(py, event_loop.to_object(py))?; - let middlewares = PyMiddlewareHandlers(self.middlewares().clone()); // Spawn a new background [std::thread] to run the application. tracing::debug!("Start the Tokio runtime in a background task"); thread::spawn(move || { From e20563e17150da176a5227ac5fd6c82252974999 Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Thu, 15 Sep 2022 23:00:45 -0700 Subject: [PATCH 04/30] Support responses, errors and header updating in middleware chain Signed-off-by: Bigo <1781140+crisidev@users.noreply.github.com> --- .../generators/PythonApplicationGenerator.kt | 43 ++++++---- .../generators/PythonServerModuleGenerator.kt | 1 + .../examples/pokemon_service.py | 15 +++- .../src/error.rs | 15 +++- .../aws-smithy-http-server-python/src/lib.rs | 4 +- .../src/middleware/mod.rs | 53 +++++++----- .../src/middleware/request.rs | 81 +++++++------------ .../src/middleware/response.rs | 46 +++++++++++ 8 files changed, 161 insertions(+), 97 deletions(-) create mode 100644 rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index d02f339d77..f042bbed67 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -309,37 +309,46 @@ class PythonApplicationGenerator( Result<#{http}::Request, #{http}::Response>, >; - fn run(&mut self, request: http::Request) -> Self::Future { + fn run(&mut self, mut request: #{http}::Request) -> Self::Future { let handlers = self.handlers.clone(); let locals = self.locals.clone(); Box::pin(async move { // Run all Python handlers in a loop. for handler in handlers { - let pyrequest = if handler.with_body { - #{SmithyPython}::PyRequest::new_with_body(&request).await - } else { - #{SmithyPython}::PyRequest::new(&request) - }; + let pyrequest = #{SmithyPython}::PyRequest::new(&request); let loop_locals = locals.clone(); let result = #{pyo3_asyncio}::tokio::scope( loop_locals, - #{SmithyPython}::py_middleware_wrapper(pyrequest, handler), - ); - if let Err(e) = result.await { - let error = crate::operation_ser::serialize_structure_crate_error_internal_server_error( - &e.into() - ).unwrap(); - let boxed_error = #{SmithyServer}::body::boxed(error); - return Err(#{http}::Response::builder() - .status(500) - .body(boxed_error) - .unwrap()); + #{SmithyPython}::execute_middleware(pyrequest, handler), + ).await; + match result { + Ok((pyrequest, pyresponse)) => { + if let Some(pyrequest) = pyrequest { + if let Ok(headers) = (&pyrequest.headers).try_into() { + *request.headers_mut() = headers; + } + } + if let Some(pyresponse) = pyresponse { + return Err(pyresponse.try_into().unwrap()); + } + }, + Err(e) => { + let error = crate::operation_ser::serialize_structure_crate_error_internal_server_error( + &e.into() + ).unwrap(); + let boxed_error = aws_smithy_http_server::body::boxed(error); + return Err(http::Response::builder() + .status(500) + .body(boxed_error) + .unwrap()); + } } } Ok(request) }) } } + impl std::convert::From for crate::error::InternalServerError { fn from(variant: pyo3::PyErr) -> Self { crate::error::InternalServerError { diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt index 538a6e95c6..0fdeaa94c1 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt @@ -131,6 +131,7 @@ class PythonServerModuleGenerator( """ let middleware = #{pyo3}::types::PyModule::new(py, "middleware")?; middleware.add_class::<#{SmithyPython}::PyRequest>()?; + middleware.add_class::<#{SmithyPython}::PyResponse>()?; pyo3::py_run!( py, middleware, diff --git a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py index e943c1d766..57d22bcd7c 100644 --- a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py +++ b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py @@ -119,24 +119,31 @@ def get_random_radio_stream(self) -> str: ########################################################### @app.middleware def check_content_type_header(request: Request): - content_type = request.get_header("content-type") + content_type = request.headers.get("content-type") if content_type == "application/json": logging.debug("Found valid `application/json` content type") else: logging.error(f"Invalid content type: {content_type}") +@app.middleware +def modify_request(request: Request): + request.headers["x-amzn-stuff"] = "42" + logging.debug("Setting `x-amzn-stuff` header") + return request + + @app.middleware async def check_method_and_content_length(request: Request): - content_length = request.get_header("content-length") - logging.debug(f"Request method: {request.method()}") + content_length = request.headers.get("content-length") + logging.debug(f"Request method: {request.method}") if content_length is not None: content_length = int(content_length) logging.debug( "Request content length: {content_length}" ) else: - logging.error(f"Invalid content length: {content_length}") + logging.error(f"Invalid content length. Dumping headers: {request.headers}") ########################################################### diff --git a/rust-runtime/aws-smithy-http-server-python/src/error.rs b/rust-runtime/aws-smithy-http-server-python/src/error.rs index 42800954be..e38aed26de 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/error.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/error.rs @@ -6,8 +6,9 @@ //! Python error definition. use aws_smithy_types::date_time::{ConversionError, DateTimeParseError}; -use pyo3::{exceptions::PyException, PyErr}; +use pyo3::{exceptions::PyException, PyErr, create_exception}; use thiserror::Error; +use http::{Error as HttpError, status::InvalidStatusCode, header::ToStrError}; /// Python error that implements foreign errors. #[derive(Error, Debug)] @@ -18,10 +19,20 @@ pub enum Error { /// Implements `From`. #[error("DateTimeParse: {0}")] DateTimeParse(#[from] DateTimeParseError), + /// Http errors + #[error("HTTP error: {0}")] + Http(#[from] HttpError), + /// Status code error + #[error("{0}")] + HttpStatusCode(#[from] InvalidStatusCode), + #[error("{0}")] + StrConversion(#[from] ToStrError ) } +create_exception!(smithy, PyError, PyException); + impl From for PyErr { fn from(other: Error) -> PyErr { - PyException::new_err(other.to_string()) + PyError::new_err(other.to_string()) } } diff --git a/rust-runtime/aws-smithy-http-server-python/src/lib.rs b/rust-runtime/aws-smithy-http-server-python/src/lib.rs index 3c6c216c21..66664ff0d6 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/lib.rs @@ -24,8 +24,8 @@ pub use error::Error; pub use logging::LogLevel; #[doc(inline)] pub use middleware::{ - py_middleware_wrapper, PyMiddleware, PyMiddlewareException, PyMiddlewareHandler, - PyMiddlewareLayer, PyRequest, + execute_middleware, PyMiddleware, PyMiddlewareException, PyMiddlewareHandler, + PyMiddlewareLayer, PyRequest, PyResponse, }; #[doc(inline)] pub use server::{PyApp, PyHandler}; diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs index 9938a5d897..b5d8016cc9 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs @@ -2,9 +2,11 @@ use pyo3::prelude::*; mod layer; mod request; +mod response; pub use self::layer::{PyMiddleware, PyMiddlewareLayer}; pub use self::request::PyRequest; +pub use self::response::PyResponse; #[pyclass(name = "MiddlewareException", extends = pyo3::exceptions::PyException)] #[derive(Debug, Clone)] @@ -36,36 +38,45 @@ pub struct PyMiddlewareHandler { // Our request handler. This is where we would implement the application logic // for responding to HTTP requests... -pub async fn py_middleware_wrapper( +pub async fn execute_middleware( request: PyRequest, handler: PyMiddlewareHandler, -) -> PyResult<()> { - let result = if handler.is_coroutine { - tracing::debug!("Executing Python handler coroutine `stream_pokemon_radio_operation()`"); +) -> PyResult<(Option, Option)> { + let handle: PyResult> = if handler.is_coroutine { + tracing::debug!("Executing Python middleware coroutine `{}`", handler.name); let result = pyo3::Python::with_gil(|py| { let pyhandler: &pyo3::types::PyFunction = handler.func.extract(py)?; let coroutine = pyhandler.call1((request,))?; pyo3_asyncio::tokio::into_future(coroutine) })?; - result.await.map(|_| ()) + let output = result.await?; + Ok(output) } else { - tracing::debug!("Executing Python handler function `stream_pokemon_radio_operation()`"); - tokio::task::block_in_place(move || { - pyo3::Python::with_gil(|py| { - let pyhandler: &pyo3::types::PyFunction = handler.func.extract(py)?; - pyhandler.call1((request,))?; - Ok(()) - }) + tracing::debug!("Executing Python middleware function `{}`", handler.name); + pyo3::Python::with_gil(|py| { + let pyhandler: &pyo3::types::PyFunction = handler.func.extract(py)?; + let output = pyhandler.call1((request,))?; + Ok(output.into_py(py)) }) }; // Catch and record a Python traceback. - result.map_err(|e| { - let traceback = pyo3::Python::with_gil(|py| match e.traceback(py) { - Some(t) => t.format().unwrap_or_else(|e| e.to_string()), - None => "Unknown traceback\n".to_string(), - }); - tracing::error!("{}{}", traceback, e); - e - })?; - Ok(()) + Python::with_gil(|py| match handle { + Ok(result) => { + if let Ok(request) = result.extract::(py) { + return Ok((Some(request), None)); + } + if let Ok(response) = result.extract::(py) { + return Ok((None, Some(response))); + } + Ok((None, None)) + } + Err(e) => { + let traceback = pyo3::Python::with_gil(|py| match e.traceback(py) { + Some(t) => t.format().unwrap_or_else(|e| e.to_string()), + None => "Unknown traceback\n".to_string(), + }); + tracing::error!("{}{}", traceback, e); + Err(e) + } + }) } diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs index acce79504e..34e49afba2 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs @@ -1,66 +1,45 @@ -use bytes::Bytes; +use std::collections::HashMap; + use http::Request; use pyo3::prelude::*; #[pyclass(name = "Request")] -#[derive(Debug)] -pub struct PyRequest(Request); +#[derive(Debug, Clone)] +pub struct PyRequest { + #[pyo3(get, set)] + method: String, + #[pyo3(get, set)] + uri: String, + #[pyo3(get, set)] + pub headers: HashMap, +} impl PyRequest { pub fn new(request: &Request) -> Self { - let mut self_ = Request::builder() - .uri(request.uri()) - .method(request.method()) - .body(Bytes::new()) - .unwrap(); - let headers = self_.headers_mut(); - *headers = request.headers().clone(); - Self(self_) - } - - pub fn new_with_body(request: &Request) -> Self { - let mut self_ = Request::builder() - .uri(request.uri()) - .method(request.method()) - .body(Bytes::new()) - .unwrap(); - let headers = self_.headers_mut(); - *headers = request.headers().clone(); - Self(self_) - } -} - -impl Clone for PyRequest { - fn clone(&self) -> Self { - let mut request = Request::builder() - .uri(self.0.uri()) - .method(self.0.method()) - .body(self.0.body().clone()) - .unwrap(); - let headers = request.headers_mut(); - *headers = self.0.headers().clone(); - Self(request) + Self { + method: request.method().to_string(), + uri: request.uri().to_string(), + headers: request + .headers() + .into_iter() + .map(|(k, v)| -> (String, String) { + let name: String = k.as_str().to_string(); + let value: String = String::from_utf8_lossy(v.as_bytes()).to_string(); + (name, value) + }) + .collect(), + } } } #[pymethods] impl PyRequest { - fn method(&self) -> String { - self.0.method().to_string() - } - - fn uri(&self) -> String { - self.0.uri().to_string() - } - - fn get_header(&self, name: &str) -> Option { - let value = self.0.headers().get(name); - match value { - Some(v) => match v.to_str() { - Ok(v) => Some(v.to_string()), - Err(_) => None, - }, - None => None, + #[new] + fn newpy(method: String, uri: String, headers: Option>) -> Self { + Self { + method, + uri, + headers: headers.unwrap_or_default(), } } } diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs new file mode 100644 index 0000000000..4878877c98 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs @@ -0,0 +1,46 @@ +use std::{ + collections::HashMap, + convert::{TryFrom, TryInto}, str::FromStr, +}; + +use aws_smithy_http_server::body::{to_boxed, BoxBody}; +use http::{HeaderMap, HeaderValue, Response, StatusCode, header::HeaderName}; +use pyo3::prelude::*; + +use crate::error::PyError; + +#[pyclass(name = "Response")] +#[derive(Debug, Clone)] +pub struct PyResponse { + #[pyo3(get, set)] + status: u16, + #[pyo3(get, set)] + body: Vec, + #[pyo3(get, set)] + headers: HashMap, +} + +#[pymethods] +impl PyResponse { + #[new] + fn newpy(status: u16, headers: Option>, body: Option>) -> Self { + Self { + status, + body: body.unwrap_or_default(), + headers: headers.unwrap_or_default(), + } + } +} + +impl From for Response { + fn from(val: PyResponse) -> Self { + let mut response = Response::builder() + .status(StatusCode::from_u16(val.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)) + .body(to_boxed(val.body)) + .unwrap_or_default(); + if let Ok(headers) = (&val.headers).try_into() { + *response.headers_mut() = headers; + } + response + } +} From 583c9672840b618e4df3c731446ecdb415870f2e Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Fri, 16 Sep 2022 10:23:17 -0700 Subject: [PATCH 05/30] Refactor and make errors consistently working Signed-off-by: Bigo <1781140+crisidev@users.noreply.github.com> --- .../generators/PythonApplicationGenerator.kt | 99 +++----------- .../generators/PythonServerModuleGenerator.kt | 1 + .../aws-smithy-http-server-python/Cargo.toml | 3 +- .../examples/pokemon_service.py | 3 +- .../src/error.rs | 81 +++++++++++- .../aws-smithy-http-server-python/src/lib.rs | 7 +- .../src/middleware/handler.rs | 125 ++++++++++++++++++ .../src/middleware/layer.rs | 75 ++++++----- .../src/middleware/mod.rs | 94 ++++--------- .../src/middleware/response.rs | 9 +- .../src/server.rs | 12 +- .../src/types.rs | 10 +- 12 files changed, 299 insertions(+), 220 deletions(-) create mode 100644 rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index f042bbed67..edd73bf44f 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -18,6 +18,7 @@ import software.amazon.smithy.rust.codegen.client.smithy.CoreCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.Errors import software.amazon.smithy.rust.codegen.client.smithy.Inputs import software.amazon.smithy.rust.codegen.client.smithy.Outputs +import software.amazon.smithy.rust.codegen.client.rustlang.escape import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.outputShape @@ -69,6 +70,7 @@ class PythonApplicationGenerator( private val libName = "lib${coreCodegenContext.settings.moduleName.toSnakeCase()}" private val runtimeConfig = coreCodegenContext.runtimeConfig private val model = coreCodegenContext.model + private val protocol = coreCodegenContext.protocol private val codegenScope = arrayOf( "SmithyPython" to PythonServerCargoDependency.SmithyHttpServerPython(runtimeConfig).asType(), @@ -93,7 +95,6 @@ class PythonApplicationGenerator( renderPyAppTrait(writer) renderAppImpl(writer) renderPyMethods(writer) - renderPyMiddleware(writer) } fun renderAppStruct(writer: RustWriter) { @@ -103,7 +104,7 @@ class PythonApplicationGenerator( ##[derive(Debug, Default)] pub struct App { handlers: #{HashMap}, - middlewares: Vec<#{SmithyPython}::PyMiddlewareHandler>, + middlewares: #{SmithyPython}::PyMiddlewareHandlers, context: Option<#{pyo3}::PyObject>, workers: #{parking_lot}::Mutex>, } @@ -167,12 +168,13 @@ class PythonApplicationGenerator( rustTemplate( """ let middleware_locals = pyo3_asyncio::TaskLocals::new(event_loop); - let middlewares = PyMiddlewareHandlers { - handlers: self.middlewares.clone(), - locals: middleware_locals, - }; + use #{SmithyPython}::PyApp; let service = #{tower}::ServiceBuilder::new().layer( - #{SmithyPython}::PyMiddlewareLayer::new(middlewares) + #{SmithyPython}::PyMiddlewareLayer::new( + self.middlewares.clone(), + self.protocol(), + middleware_locals + ), ); let router: #{SmithyServer}::Router = router .build() @@ -187,6 +189,7 @@ class PythonApplicationGenerator( } private fun renderPyAppTrait(writer: RustWriter) { + val protocol = protocol.toString().replace("#", "##") writer.rustTemplate( """ impl #{SmithyPython}::PyApp for App { @@ -199,9 +202,12 @@ class PythonApplicationGenerator( fn handlers(&mut self) -> &mut #{HashMap} { &mut self.handlers } - fn middlewares(&mut self) -> &mut Vec<#{SmithyPython}::PyMiddlewareHandler> { + fn middlewares(&mut self) -> &mut #{SmithyPython}::PyMiddlewareHandlers { &mut self.middlewares } + fn protocol(&self) -> &'static str { + "$protocol" + } } """, *codegenScope, @@ -234,13 +240,7 @@ class PythonApplicationGenerator( ##[pyo3(text_signature = "(${'$'}self, func)")] pub fn middleware(&mut self, py: pyo3::Python, func: pyo3::PyObject) -> pyo3::PyResult<()> { use #{SmithyPython}::PyApp; - self.register_middleware(py, func, false) - } - /// Register a middleware function that will be run inside a Tower layer, cloning the body. - ##[pyo3(text_signature = "(${'$'}self, func)")] - pub fn middleware_with_body(&mut self, py: pyo3::Python, func: pyo3::PyObject) -> pyo3::PyResult<()> { - use #{SmithyPython}::PyApp; - self.register_middleware(py, func, true) + self.register_middleware(py, func) } /// Main entrypoint: start the server on multiple workers. ##[pyo3(text_signature = "(${'$'}self, address, port, backlog, workers)")] @@ -290,75 +290,6 @@ class PythonApplicationGenerator( } } - private fun renderPyMiddleware(writer: RustWriter) { - writer.rustTemplate(""" - ##[derive(Debug, Clone)] - struct PyMiddlewareHandlers { - handlers: Vec<#{SmithyPython}::PyMiddlewareHandler>, - locals: #{pyo3_asyncio}::TaskLocals - } - - impl #{SmithyPython}::PyMiddleware for PyMiddlewareHandlers - where - B: Send + Sync + 'static, - { - type RequestBody = B; - type ResponseBody = #{SmithyServer}::body::BoxBody; - type Future = futures_util::future::BoxFuture< - 'static, - Result<#{http}::Request, #{http}::Response>, - >; - - fn run(&mut self, mut request: #{http}::Request) -> Self::Future { - let handlers = self.handlers.clone(); - let locals = self.locals.clone(); - Box::pin(async move { - // Run all Python handlers in a loop. - for handler in handlers { - let pyrequest = #{SmithyPython}::PyRequest::new(&request); - let loop_locals = locals.clone(); - let result = #{pyo3_asyncio}::tokio::scope( - loop_locals, - #{SmithyPython}::execute_middleware(pyrequest, handler), - ).await; - match result { - Ok((pyrequest, pyresponse)) => { - if let Some(pyrequest) = pyrequest { - if let Ok(headers) = (&pyrequest.headers).try_into() { - *request.headers_mut() = headers; - } - } - if let Some(pyresponse) = pyresponse { - return Err(pyresponse.try_into().unwrap()); - } - }, - Err(e) => { - let error = crate::operation_ser::serialize_structure_crate_error_internal_server_error( - &e.into() - ).unwrap(); - let boxed_error = aws_smithy_http_server::body::boxed(error); - return Err(http::Response::builder() - .status(500) - .body(boxed_error) - .unwrap()); - } - } - } - Ok(request) - }) - } - } - - impl std::convert::From for crate::error::InternalServerError { - fn from(variant: pyo3::PyErr) -> Self { - crate::error::InternalServerError { - message: variant.to_string(), - } - } - } - """, *codegenScope) - } - private fun renderPyApplicationRustDocs(writer: RustWriter) { writer.rust( """ diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt index 0fdeaa94c1..73872259f0 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt @@ -132,6 +132,7 @@ class PythonServerModuleGenerator( let middleware = #{pyo3}::types::PyModule::new(py, "middleware")?; middleware.add_class::<#{SmithyPython}::PyRequest>()?; middleware.add_class::<#{SmithyPython}::PyResponse>()?; + middleware.add_class::<#{SmithyPython}::PyMiddlewareException>()?; pyo3::py_run!( py, middleware, diff --git a/rust-runtime/aws-smithy-http-server-python/Cargo.toml b/rust-runtime/aws-smithy-http-server-python/Cargo.toml index e05c100299..fc9a266693 100644 --- a/rust-runtime/aws-smithy-http-server-python/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server-python/Cargo.toml @@ -13,9 +13,10 @@ Python server runtime for Smithy Rust Server Framework. publish = true [dependencies] +aws-smithy-http = { path = "../aws-smithy-http" } aws-smithy-http-server = { path = "../aws-smithy-http-server" } +aws-smithy-json = { path = "../aws-smithy-json" } aws-smithy-types = { path = "../aws-smithy-types" } -aws-smithy-http = { path = "../aws-smithy-http" } bytes = "1.2" futures = "0.3" futures-core = "0.3" diff --git a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py index 57d22bcd7c..4cd819bd20 100644 --- a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py +++ b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py @@ -17,7 +17,7 @@ from libpokemon_service_server_sdk.input import ( EmptyOperationInput, GetPokemonSpeciesInput, GetServerStatisticsInput, HealthCheckOperationInput, StreamPokemonRadioOperationInput) -from libpokemon_service_server_sdk.middleware import Request +from libpokemon_service_server_sdk.middleware import Request, MiddlewareException from libpokemon_service_server_sdk.model import FlavorText, Language from libpokemon_service_server_sdk.output import ( EmptyOperationOutput, GetPokemonSpeciesOutput, GetServerStatisticsOutput, @@ -124,6 +124,7 @@ def check_content_type_header(request: Request): logging.debug("Found valid `application/json` content type") else: logging.error(f"Invalid content type: {content_type}") + raise MiddlewareException("cmon", 404) @app.middleware diff --git a/rust-runtime/aws-smithy-http-server-python/src/error.rs b/rust-runtime/aws-smithy-http-server-python/src/error.rs index e38aed26de..cdc20ff0f5 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/error.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/error.rs @@ -5,14 +5,16 @@ //! Python error definition. +use aws_smithy_http_server::{response::Response, body::to_boxed}; +use aws_smithy_http_server::protocols::Protocol; use aws_smithy_types::date_time::{ConversionError, DateTimeParseError}; -use pyo3::{exceptions::PyException, PyErr, create_exception}; +use pyo3::{exceptions::PyException as BasePyException, PyErr, create_exception, prelude::*}; use thiserror::Error; use http::{Error as HttpError, status::InvalidStatusCode, header::ToStrError}; /// Python error that implements foreign errors. #[derive(Error, Debug)] -pub enum Error { +pub enum PyError { /// Implements `From`. #[error("DateTimeConversion: {0}")] DateTimeConversion(#[from] ConversionError), @@ -29,10 +31,77 @@ pub enum Error { StrConversion(#[from] ToStrError ) } -create_exception!(smithy, PyError, PyException); +create_exception!(smithy, PyException, BasePyException); -impl From for PyErr { - fn from(other: Error) -> PyErr { - PyError::new_err(other.to_string()) +impl From for PyErr { + fn from(other: PyError ) -> PyErr { + PyException::new_err(other.to_string()) + } +} + +#[pyclass(name = "MiddlewareException", extends = BasePyException)] +#[derive(Debug, Clone)] +pub struct PyMiddlewareException { + #[pyo3(get, set)] + pub message: String, + #[pyo3(get, set)] + pub status_code: u16, +} + +#[pymethods] +impl PyMiddlewareException { + #[new] + fn newpy(message: String, status_code: Option) -> Self { + Self { + message, + status_code: status_code.unwrap_or(500), + } + } +} + +impl From for PyMiddlewareException { + fn from(other: PyErr) -> Self { + Self::newpy(other.to_string(), None) + } +} + +impl PyMiddlewareException { + fn json_body(&self) -> String { + let mut out = String::new(); + let mut object = aws_smithy_json::serialize::JsonObjectWriter::new(&mut out); + object.key("message").string(self.message.as_str()); + object.finish(); + out + } + + fn xml_body(&self) -> String { + "".to_string() + } + + pub fn into_response(self, protocol: Protocol) -> Response { + let body = to_boxed(match protocol { + Protocol::RestJson1 => self.json_body(), + Protocol::RestXml => self.xml_body(), + // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_0-protocol.html#empty-body-serialization + Protocol::AwsJson10 => self.json_body(), + // See https://awslabs.github.io/smithy/1.0/spec/aws/aws-json-1_1-protocol.html#empty-body-serialization + Protocol::AwsJson11 => self.json_body(), + }); + + let mut builder = http::Response::builder(); + builder = builder.status(self.status_code); + + match protocol { + Protocol::RestJson1 => { + builder = builder + .header("Content-Type", "application/json") + .header("X-Amzn-Errortype", "MiddlewareException"); + } + Protocol::RestXml => builder = builder.header("Content-Type", "application/xml"), + Protocol::AwsJson10 => builder = builder.header("Content-Type", "application/x-amz-json-1.0"), + Protocol::AwsJson11 => builder = builder.header("Content-Type", "application/x-amz-json-1.1"), + } + + builder.body(body).expect("invalid HTTP response for `RuntimeError`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") } } diff --git a/rust-runtime/aws-smithy-http-server-python/src/lib.rs b/rust-runtime/aws-smithy-http-server-python/src/lib.rs index 66664ff0d6..e79dddf8c7 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/lib.rs @@ -19,14 +19,11 @@ mod socket; pub mod types; #[doc(inline)] -pub use error::Error; +pub use error::{PyError, PyMiddlewareException}; #[doc(inline)] pub use logging::LogLevel; #[doc(inline)] -pub use middleware::{ - execute_middleware, PyMiddleware, PyMiddlewareException, PyMiddlewareHandler, - PyMiddlewareLayer, PyRequest, PyResponse, -}; +pub use middleware::{PyMiddlewareHandlers, PyMiddlewareLayer, PyRequest, PyResponse}; #[doc(inline)] pub use server::{PyApp, PyHandler}; #[doc(inline)] diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs new file mode 100644 index 0000000000..c6c759e068 --- /dev/null +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs @@ -0,0 +1,125 @@ +use aws_smithy_http_server::body::BoxBody; +use futures::future::BoxFuture; +use http::{Request, Response}; +use pyo3::prelude::*; + +use aws_smithy_http_server::protocols::Protocol; +use pyo3_asyncio::TaskLocals; + +use crate::{PyMiddlewareException, PyRequest, PyResponse}; + +use super::PyMiddlewareTrait; + +#[derive(Debug, Clone)] +pub struct PyMiddlewareHandler { + pub name: String, + pub func: PyObject, + pub is_coroutine: bool, +} + +#[derive(Debug, Clone, Default)] +pub struct PyMiddlewareHandlers(Vec); + +impl PyMiddlewareHandlers { + pub fn new(handlers: Vec) -> Self { + Self(handlers) + } + + pub fn push(&mut self, handler: PyMiddlewareHandler) { + self.0.push(handler); + } + + // Our request handler. This is where we would implement the application logic + // for responding to HTTP requests... + async fn execute_middleware( + request: PyRequest, + handler: PyMiddlewareHandler, + ) -> Result<(Option, Option), PyMiddlewareException> { + let handle: PyResult> = if handler.is_coroutine { + tracing::debug!("Executing Python middleware coroutine `{}`", handler.name); + let result = pyo3::Python::with_gil(|py| { + let pyhandler: &pyo3::types::PyFunction = handler.func.extract(py)?; + let coroutine = pyhandler.call1((request,))?; + pyo3_asyncio::tokio::into_future(coroutine) + })?; + let output = result.await?; + Ok(output) + } else { + tracing::debug!("Executing Python middleware function `{}`", handler.name); + pyo3::Python::with_gil(|py| { + let pyhandler: &pyo3::types::PyFunction = handler.func.extract(py)?; + let output = pyhandler.call1((request,))?; + Ok(output.into_py(py)) + }) + }; + // Catch and record a Python traceback. + Python::with_gil(|py| match handle { + Ok(result) => { + if let Ok(request) = result.extract::(py) { + return Ok((Some(request), None)); + } + if let Ok(response) = result.extract::(py) { + return Ok((None, Some(response))); + } + Ok((None, None)) + } + Err(e) => pyo3::Python::with_gil(|py| { + let traceback = match e.traceback(py) { + Some(t) => t.format().unwrap_or_else(|e| e.to_string()), + None => "Unknown traceback\n".to_string(), + }; + tracing::error!("{}{}", traceback, e); + let variant = e.value(py); + if let Ok(v) = variant.extract::() { + Err(v) + } else { + Err(e.into()) + } + }), + }) + } +} + +impl PyMiddlewareTrait for PyMiddlewareHandlers +where + B: Send + Sync + 'static, +{ + type RequestBody = B; + type ResponseBody = BoxBody; + type Future = BoxFuture<'static, Result, Response>>; + + fn run( + &mut self, + mut request: http::Request, + protocol: Protocol, + locals: TaskLocals, + ) -> Self::Future { + let handlers = self.0.clone(); + Box::pin(async move { + // Run all Python handlers in a loop. + for handler in handlers { + let pyrequest = PyRequest::new(&request); + let loop_locals = locals.clone(); + let result = pyo3_asyncio::tokio::scope( + loop_locals, + Self::execute_middleware(pyrequest, handler), + ) + .await; + match result { + Ok((pyrequest, pyresponse)) => { + if let Some(pyrequest) = pyrequest { + if let Ok(headers) = (&pyrequest.headers).try_into() { + *request.headers_mut() = headers; + } + } + if let Some(pyresponse) = pyresponse { + return Err(pyresponse.into()); + } + } + Err(e) => return Err(e.into_response(protocol)), + } + } + Ok(request) + }) + } +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs index 15af036ef4..75752bbeaf 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs @@ -1,18 +1,38 @@ -use std::{task::{Context, Poll}, pin::Pin}; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; -use futures::{Future, ready}; +use aws_smithy_http_server::protocols::Protocol; +use futures::{ready, Future}; use http::{Request, Response}; use pin_project_lite::pin_project; +use pyo3_asyncio::TaskLocals; use tower::{Layer, Service}; +use super::PyMiddlewareTrait; + #[derive(Debug, Clone)] pub struct PyMiddlewareLayer { handler: T, + protocol: Protocol, + locals: TaskLocals, } impl PyMiddlewareLayer { - pub fn new(handler: T) -> PyMiddlewareLayer { - Self { handler } + pub fn new(handler: T, protocol: &str, locals: TaskLocals) -> PyMiddlewareLayer { + let protocol = match protocol { + "aws.protocols#restJson1" => Protocol::RestJson1, + "aws.protocols#restXml" => Protocol::RestXml, + "aws.protocols#awsjson10" => Protocol::AwsJson10, + "aws.protocols#awsjson11" => Protocol::AwsJson11, + _ => panic!(), + }; + Self { + handler, + protocol, + locals, + } } } @@ -23,7 +43,7 @@ where type Service = PyMiddlewareService; fn layer(&self, inner: S) -> Self::Service { - PyMiddlewareService::new(inner, self.handler.clone()) + PyMiddlewareService::new(inner, self.handler.clone(), self.protocol, self.locals.clone()) } } @@ -31,21 +51,28 @@ where pub struct PyMiddlewareService { inner: S, handler: T, + protocol: Protocol, + locals: TaskLocals } impl PyMiddlewareService { - pub fn new(inner: S, handler: T) -> PyMiddlewareService { - Self { inner, handler } + pub fn new(inner: S, handler: T, protocol: Protocol, locals: TaskLocals) -> PyMiddlewareService { + Self { + inner, + handler, + protocol, + locals + } } - pub fn layer(handler: T) -> PyMiddlewareLayer { - PyMiddlewareLayer::new(handler) + pub fn layer(handler: T, protocol: &str, locals: TaskLocals) -> PyMiddlewareLayer { + PyMiddlewareLayer::new(handler, protocol, locals) } } impl Service> for PyMiddlewareService where - M: PyMiddleware, + M: PyMiddlewareTrait, S: Service, Response = Response> + Clone, { type Response = Response; @@ -58,7 +85,7 @@ where fn call(&mut self, req: Request) -> Self::Future { let inner = self.inner.clone(); - let run = self.handler.run(req); + let run = self.handler.run(req, self.protocol, self.locals.clone()); ResponseFuture { middleware: State::Run { run }, @@ -70,7 +97,7 @@ where pin_project! { pub struct ResponseFuture where - M: PyMiddleware, + M: PyMiddlewareTrait, S: Service>, { #[pin] @@ -95,7 +122,7 @@ pin_project! { impl Future for ResponseFuture where - M: PyMiddleware, + M: PyMiddlewareTrait, S: Service, Response = Response>, { type Output = Result, S::Error>; @@ -119,25 +146,3 @@ where } } } - -pub trait PyMiddleware { - type RequestBody; - type ResponseBody; - type Future: Future, Response>>; - - fn run(&mut self, request: Request) -> Self::Future; -} - -impl PyMiddleware for F -where - F: FnMut(Request) -> Fut, - Fut: Future, Response>>, -{ - type RequestBody = ReqBody; - type ResponseBody = ResBody; - type Future = Fut; - - fn run(&mut self, request: Request) -> Self::Future { - self(request) - } -} diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs index b5d8016cc9..d51ba646f0 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs @@ -1,82 +1,36 @@ -use pyo3::prelude::*; - +mod handler; mod layer; mod request; mod response; -pub use self::layer::{PyMiddleware, PyMiddlewareLayer}; +use aws_smithy_http_server::protocols::Protocol; +use futures::Future; +use http::{Request, Response}; +use pyo3_asyncio::TaskLocals; + +pub use self::handler::{PyMiddlewareHandler, PyMiddlewareHandlers}; +pub use self::layer::PyMiddlewareLayer; pub use self::request::PyRequest; pub use self::response::PyResponse; -#[pyclass(name = "MiddlewareException", extends = pyo3::exceptions::PyException)] -#[derive(Debug, Clone)] -pub struct PyMiddlewareException { - #[pyo3(get, set)] - pub message: String, - #[pyo3(get, set)] - pub status_code: u16, -} +pub trait PyMiddlewareTrait { + type RequestBody; + type ResponseBody; + type Future: Future, Response>>; -#[pymethods] -impl PyMiddlewareException { - #[new] - fn newpy(message: String, status_code: u16) -> Self { - Self { - message, - status_code, - } - } + fn run(&mut self, request: Request, protocol: Protocol, locals: TaskLocals) -> Self::Future; } -#[derive(Debug, Clone)] -pub struct PyMiddlewareHandler { - pub name: String, - pub func: PyObject, - pub is_coroutine: bool, - pub with_body: bool, -} +impl PyMiddlewareTrait for F +where + F: FnMut(Request) -> Fut, + Fut: Future, Response>>, +{ + type RequestBody = ReqBody; + type ResponseBody = ResBody; + type Future = Fut; -// Our request handler. This is where we would implement the application logic -// for responding to HTTP requests... -pub async fn execute_middleware( - request: PyRequest, - handler: PyMiddlewareHandler, -) -> PyResult<(Option, Option)> { - let handle: PyResult> = if handler.is_coroutine { - tracing::debug!("Executing Python middleware coroutine `{}`", handler.name); - let result = pyo3::Python::with_gil(|py| { - let pyhandler: &pyo3::types::PyFunction = handler.func.extract(py)?; - let coroutine = pyhandler.call1((request,))?; - pyo3_asyncio::tokio::into_future(coroutine) - })?; - let output = result.await?; - Ok(output) - } else { - tracing::debug!("Executing Python middleware function `{}`", handler.name); - pyo3::Python::with_gil(|py| { - let pyhandler: &pyo3::types::PyFunction = handler.func.extract(py)?; - let output = pyhandler.call1((request,))?; - Ok(output.into_py(py)) - }) - }; - // Catch and record a Python traceback. - Python::with_gil(|py| match handle { - Ok(result) => { - if let Ok(request) = result.extract::(py) { - return Ok((Some(request), None)); - } - if let Ok(response) = result.extract::(py) { - return Ok((None, Some(response))); - } - Ok((None, None)) - } - Err(e) => { - let traceback = pyo3::Python::with_gil(|py| match e.traceback(py) { - Some(t) => t.format().unwrap_or_else(|e| e.to_string()), - None => "Unknown traceback\n".to_string(), - }); - tracing::error!("{}{}", traceback, e); - Err(e) - } - }) + fn run(&mut self, request: Request, _protocol: Protocol, _locals: TaskLocals) -> Self::Future { + self(request) + } } diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs index 4878877c98..5fe009cf30 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs @@ -1,14 +1,9 @@ -use std::{ - collections::HashMap, - convert::{TryFrom, TryInto}, str::FromStr, -}; +use std::{collections::HashMap, convert::TryInto}; use aws_smithy_http_server::body::{to_boxed, BoxBody}; -use http::{HeaderMap, HeaderValue, Response, StatusCode, header::HeaderName}; +use http::{Response, StatusCode}; use pyo3::prelude::*; -use crate::error::PyError; - #[pyclass(name = "Response")] #[derive(Debug, Clone)] pub struct PyResponse { diff --git a/rust-runtime/aws-smithy-http-server-python/src/server.rs b/rust-runtime/aws-smithy-http-server-python/src/server.rs index 454f4b23f7..4b6c0e0089 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/server.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/server.rs @@ -13,7 +13,7 @@ use signal_hook::{consts::*, iterator::Signals}; use tokio::runtime; use tower::ServiceBuilder; -use crate::{PySocket, PyMiddlewareHandler}; +use crate::{PySocket, middleware::PyMiddlewareHandler, PyMiddlewareHandlers}; /// A Python handler function representation. /// @@ -61,7 +61,9 @@ pub trait PyApp: Clone + pyo3::IntoPy { /// Mapping between operation names and their `PyHandler` representation. fn handlers(&mut self) -> &mut HashMap; - fn middlewares(&mut self) -> &mut Vec; + fn middlewares(&mut self) -> &mut PyMiddlewareHandlers; + + fn protocol(&self) -> &'static str; /// Handle the graceful termination of Python workers by looping through all the /// active workers and calling `terminate()` on them. If termination fails, this @@ -255,7 +257,7 @@ event_loop.add_signal_handler(signal.SIGINT, Ok(()) } - fn register_middleware(&mut self, py: Python, func: PyObject, with_body: bool) -> PyResult<()> { + fn register_middleware(&mut self, py: Python, func: PyObject) -> PyResult<()> { let inspect = py.import("inspect")?; // Check if the function is a coroutine. // NOTE: that `asyncio.iscoroutine()` doesn't work here. @@ -268,13 +270,11 @@ event_loop.add_signal_handler(signal.SIGINT, name, func, is_coroutine, - with_body, }; tracing::info!( - "Registering middleware function `{}`, coroutine: {}, with_body: {}", + "Registering middleware function `{}`, coroutine: {}", handler.name, handler.is_coroutine, - handler.with_body, ); self.middlewares().push(handler); Ok(()) diff --git a/rust-runtime/aws-smithy-http-server-python/src/types.rs b/rust-runtime/aws-smithy-http-server-python/src/types.rs index 9e288badbc..98adbd4119 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/types.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/types.rs @@ -22,7 +22,7 @@ use pyo3::{ use tokio::sync::Mutex; use tokio_stream::StreamExt; -use crate::Error; +use crate::PyError; /// Python Wrapper for [aws_smithy_types::Blob]. #[pyclass] @@ -152,7 +152,7 @@ impl DateTime { pub fn from_nanos(epoch_nanos: i128) -> PyResult { Ok(Self( aws_smithy_types::date_time::DateTime::from_nanos(epoch_nanos) - .map_err(Error::DateTimeConversion)?, + .map_err(PyError::DateTimeConversion)?, )) } @@ -160,7 +160,7 @@ impl DateTime { #[staticmethod] pub fn read(s: &str, format: Format, delim: char) -> PyResult<(Self, &str)> { let (self_, next) = aws_smithy_types::date_time::DateTime::read(s, format.into(), delim) - .map_err(Error::DateTimeParse)?; + .map_err(PyError::DateTimeParse)?; Ok((Self(self_), next)) } @@ -195,7 +195,7 @@ impl DateTime { pub fn from_str(s: &str, format: Format) -> PyResult { Ok(Self( aws_smithy_types::date_time::DateTime::from_str(s, format.into()) - .map_err(Error::DateTimeParse)?, + .map_err(PyError::DateTimeParse)?, )) } @@ -226,7 +226,7 @@ impl DateTime { /// Converts the `DateTime` to the number of milliseconds since the Unix epoch. pub fn to_millis(&self) -> PyResult { - Ok(self.0.to_millis().map_err(Error::DateTimeConversion)?) + Ok(self.0.to_millis().map_err(PyError::DateTimeConversion)?) } } From 047401cf57be6909182240e2731e853dc62e5075 Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Fri, 16 Sep 2022 12:17:41 -0700 Subject: [PATCH 06/30] Fix the ability of changing the request between middlewares --- .../generators/PythonServerModuleGenerator.kt | 1 + .../AdditionalErrorsDecorator.kt | 51 +++-------------- .../aws-smithy-http-server-python/Cargo.toml | 1 + .../examples/pokemon_service.py | 44 ++++++++++---- .../src/error.rs | 35 +++++++++--- .../aws-smithy-http-server-python/src/lib.rs | 2 +- .../src/middleware/handler.rs | 11 +++- .../src/middleware/mod.rs | 2 +- .../src/middleware/request.rs | 57 ++++++++++++++++++- .../src/middleware/response.rs | 26 ++++++--- 10 files changed, 153 insertions(+), 77 deletions(-) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt index 73872259f0..ad639aa5b6 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt @@ -133,6 +133,7 @@ class PythonServerModuleGenerator( middleware.add_class::<#{SmithyPython}::PyRequest>()?; middleware.add_class::<#{SmithyPython}::PyResponse>()?; middleware.add_class::<#{SmithyPython}::PyMiddlewareException>()?; + middleware.add_class::<#{SmithyPython}::PyHttpVersion>()?; pyo3::py_run!( py, middleware, diff --git a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AdditionalErrorsDecorator.kt b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AdditionalErrorsDecorator.kt index 8473e817f0..b4033fe7a1 100644 --- a/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AdditionalErrorsDecorator.kt +++ b/codegen-server/src/main/kotlin/software/amazon/smithy/rust/codegen/server/smithy/customizations/AdditionalErrorsDecorator.kt @@ -36,10 +36,8 @@ class AddInternalServerErrorToInfallibleOperationsDecorator : RustCodegenDecorat override val name: String = "AddInternalServerErrorToInfallibleOperations" override val order: Byte = 0 - override fun transformModel(service: ServiceShape, model: Model): Model { - val errorShape = internalServerError(service.id.namespace) - return addErrorShapeToModelOperations(errorShape, model) { shape -> shape.allErrors(model).isEmpty() } - } + override fun transformModel(service: ServiceShape, model: Model): Model = + addErrorShapeToModelOperations(service, model) { shape -> shape.allErrors(model).isEmpty() } override fun supportsCodegenContext(clazz: Class): Boolean = clazz.isAssignableFrom(ServerCodegenContext::class.java) @@ -67,33 +65,19 @@ class AddInternalServerErrorToAllOperationsDecorator : RustCodegenDecorator): Boolean = - clazz.isAssignableFrom(ServerCodegenContext::class.java) -} - -class AddMiddlewareErrorToAllOperationsDecorator : RustCodegenDecorator { - override val name: String = "AddMiddlewareErrorToAllOperations" - override val order: Byte = 0 - - override fun transformModel(service: ServiceShape, model: Model): Model { - val errorShape = middlwareError(service.id.namespace) - return addErrorShapeToModelOperations(errorShape, model) { true } - } + override fun transformModel(service: ServiceShape, model: Model): Model = + addErrorShapeToModelOperations(service, model) { true } override fun supportsCodegenContext(clazz: Class): Boolean = clazz.isAssignableFrom(ServerCodegenContext::class.java) } -fun addErrorShapeToModelOperations(error: StructureShape, model: Model, opSelector: (OperationShape) -> Boolean): Model { - val modelShapes = model.toBuilder().addShapes(listOf(error)).build() +fun addErrorShapeToModelOperations(service: ServiceShape, model: Model, opSelector: (OperationShape) -> Boolean): Model { + val errorShape = internalServerError(service.id.namespace) + val modelShapes = model.toBuilder().addShapes(listOf(errorShape)).build() return ModelTransformer.create().mapShapes(modelShapes) { shape -> if (shape is OperationShape && opSelector(shape)) { - shape.toBuilder().addError(error).build() + shape.toBuilder().addError(errorShape).build() } else { shape } @@ -110,22 +94,3 @@ private fun internalServerError(namespace: String): StructureShape = .addTrait(RequiredTrait()) .build(), ).build() - - -private fun middlwareError(namespace: String): StructureShape = - StructureShape.builder().id("$namespace#MiddlewareError") - .addTrait(ErrorTrait("server")) - .addMember( - MemberShape.builder() - .id("$namespace#MiddlewareError\$message") - .target("smithy.api#String") - .addTrait(RequiredTrait()) - .build(), - ) - .addMember( - MemberShape.builder() - .id("$namespace#MiddlewareError\$code") - .target("smithy.api#Integer") - .addTrait(RequiredTrait()) - .build(), - ).build() diff --git a/rust-runtime/aws-smithy-http-server-python/Cargo.toml b/rust-runtime/aws-smithy-http-server-python/Cargo.toml index fc9a266693..2be120e75e 100644 --- a/rust-runtime/aws-smithy-http-server-python/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server-python/Cargo.toml @@ -17,6 +17,7 @@ aws-smithy-http = { path = "../aws-smithy-http" } aws-smithy-http-server = { path = "../aws-smithy-http-server" } aws-smithy-json = { path = "../aws-smithy-json" } aws-smithy-types = { path = "../aws-smithy-types" } +aws-smithy-xml = { path = "../aws-smithy-xml" } bytes = "1.2" futures = "0.3" futures-core = "0.3" diff --git a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py index 4cd819bd20..86ce9a0125 100644 --- a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py +++ b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py @@ -17,7 +17,8 @@ from libpokemon_service_server_sdk.input import ( EmptyOperationInput, GetPokemonSpeciesInput, GetServerStatisticsInput, HealthCheckOperationInput, StreamPokemonRadioOperationInput) -from libpokemon_service_server_sdk.middleware import Request, MiddlewareException +from libpokemon_service_server_sdk.middleware import (MiddlewareException, + Request) from libpokemon_service_server_sdk.model import FlavorText, Language from libpokemon_service_server_sdk.output import ( EmptyOperationOutput, GetPokemonSpeciesOutput, GetServerStatisticsOutput, @@ -116,35 +117,56 @@ def get_random_radio_stream(self) -> str: ########################################################### # Middleware -########################################################### +############################################################ +# Middlewares are sync or async function decorated by `@app.middleware`. +# They are executed in order and take as input the HTTP request object. +# A middleware can return multiple values, following these rules: +# * Middleware not returning will let the execution continue without +# changing the original request. +# * Middleware returning a modified Request will update the original +# request before continuing the execution. +# * Middleware returnign a Response will immediately terminate the request +# handling and return the response constructed from Python. +# * Middleware raising MiddlewareException will immediately terminate the +# request handling and return a protocol specific error, with the option of +# setting the HTTP return code. +# * Middleware raising any other exception will immediately terminate the +# request handling and return a protocol specific error, with HTTP status +# code 500. @app.middleware def check_content_type_header(request: Request): - content_type = request.headers.get("content-type") + content_type = request.get_header("content-type") if content_type == "application/json": logging.debug("Found valid `application/json` content type") else: logging.error(f"Invalid content type: {content_type}") - raise MiddlewareException("cmon", 404) + # Return an HTTP 401 Unauthorized if the content type is not JSON. + raise MiddlewareException("Invalid content type", 401) + # Check that `x-amzn-answer` header is not set. + assert request.get_header("x-amzn-answer") is None +# This middleware adds a new header called `x-amazon-answer` to the +# request. We expect to see this header to be populated in the next +# middleware. @app.middleware -def modify_request(request: Request): - request.headers["x-amzn-stuff"] = "42" +def add_x_amzn_stuff_header(request: Request): + request.set_header("x-amzn-answer", "42") logging.debug("Setting `x-amzn-stuff` header") return request @app.middleware async def check_method_and_content_length(request: Request): - content_length = request.headers.get("content-length") + content_length = request.get_header("content-length") logging.debug(f"Request method: {request.method}") if content_length is not None: content_length = int(content_length) - logging.debug( - "Request content length: {content_length}" - ) + logging.debug("Request content length: {content_length}") else: - logging.error(f"Invalid content length. Dumping headers: {request.headers}") + logging.error(f"Invalid content length. Dumping headers: {request.headers()}") + # Check that `x-amzn-answer` is 42. + assert request.get_header("x-amzn-answer") == "42" ########################################################### diff --git a/rust-runtime/aws-smithy-http-server-python/src/error.rs b/rust-runtime/aws-smithy-http-server-python/src/error.rs index cdc20ff0f5..417b3c9d27 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/error.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/error.rs @@ -5,12 +5,12 @@ //! Python error definition. -use aws_smithy_http_server::{response::Response, body::to_boxed}; use aws_smithy_http_server::protocols::Protocol; +use aws_smithy_http_server::{body::to_boxed, response::Response}; use aws_smithy_types::date_time::{ConversionError, DateTimeParseError}; -use pyo3::{exceptions::PyException as BasePyException, PyErr, create_exception, prelude::*}; +use http::{header::ToStrError, status::InvalidStatusCode, Error as HttpError}; +use pyo3::{create_exception, exceptions::PyException as BasePyException, prelude::*, PyErr}; use thiserror::Error; -use http::{Error as HttpError, status::InvalidStatusCode, header::ToStrError}; /// Python error that implements foreign errors. #[derive(Error, Debug)] @@ -28,13 +28,13 @@ pub enum PyError { #[error("{0}")] HttpStatusCode(#[from] InvalidStatusCode), #[error("{0}")] - StrConversion(#[from] ToStrError ) + StrConversion(#[from] ToStrError), } create_exception!(smithy, PyException, BasePyException); impl From for PyErr { - fn from(other: PyError ) -> PyErr { + fn from(other: PyError) -> PyErr { PyException::new_err(other.to_string()) } } @@ -75,7 +75,20 @@ impl PyMiddlewareException { } fn xml_body(&self) -> String { - "".to_string() + let mut out = String::new(); + { + let mut writer = aws_smithy_xml::encode::XmlWriter::new(&mut out); + let root = writer + .start_el("Error") + .write_ns("http://s3.amazonaws.com/doc/2006-03-01/", None); + let mut scope = root.finish(); + { + let mut inner_writer = scope.start_el("Message").finish(); + inner_writer.data(self.message.as_ref()); + } + scope.finish(); + } + out } pub fn into_response(self, protocol: Protocol) -> Response { @@ -98,10 +111,14 @@ impl PyMiddlewareException { .header("X-Amzn-Errortype", "MiddlewareException"); } Protocol::RestXml => builder = builder.header("Content-Type", "application/xml"), - Protocol::AwsJson10 => builder = builder.header("Content-Type", "application/x-amz-json-1.0"), - Protocol::AwsJson11 => builder = builder.header("Content-Type", "application/x-amz-json-1.1"), + Protocol::AwsJson10 => { + builder = builder.header("Content-Type", "application/x-amz-json-1.0") + } + Protocol::AwsJson11 => { + builder = builder.header("Content-Type", "application/x-amz-json-1.1") + } } - builder.body(body).expect("invalid HTTP response for `RuntimeError`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") + builder.body(body).expect("invalid HTTP response for `MiddlewareException`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") } } diff --git a/rust-runtime/aws-smithy-http-server-python/src/lib.rs b/rust-runtime/aws-smithy-http-server-python/src/lib.rs index e79dddf8c7..e87564ba60 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/lib.rs @@ -23,7 +23,7 @@ pub use error::{PyError, PyMiddlewareException}; #[doc(inline)] pub use logging::LogLevel; #[doc(inline)] -pub use middleware::{PyMiddlewareHandlers, PyMiddlewareLayer, PyRequest, PyResponse}; +pub use middleware::{PyMiddlewareHandlers, PyMiddlewareLayer, PyRequest, PyResponse, PyHttpVersion}; #[doc(inline)] pub use server::{PyApp, PyHandler}; #[doc(inline)] diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs index c6c759e068..2f7c7b5b0d 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs @@ -52,7 +52,6 @@ impl PyMiddlewareHandlers { Ok(output.into_py(py)) }) }; - // Catch and record a Python traceback. Python::with_gil(|py| match handle { Ok(result) => { if let Ok(request) = result.extract::(py) { @@ -97,7 +96,9 @@ where let handlers = self.0.clone(); Box::pin(async move { // Run all Python handlers in a loop. + tracing::debug!("Executing Python middleware stack"); for handler in handlers { + let name = handler.name.clone(); let pyrequest = PyRequest::new(&request); let loop_locals = locals.clone(); let result = pyo3_asyncio::tokio::scope( @@ -109,16 +110,22 @@ where Ok((pyrequest, pyresponse)) => { if let Some(pyrequest) = pyrequest { if let Ok(headers) = (&pyrequest.headers).try_into() { + tracing::debug!("Middleware `{name}` returned an HTTP request, override headers with middleware's one"); *request.headers_mut() = headers; } } if let Some(pyresponse) = pyresponse { + tracing::debug!("Middleware `{name}` returned a HTTP response, exit middleware loop"); return Err(pyresponse.into()); } } - Err(e) => return Err(e.into_response(protocol)), + Err(e) => { + tracing::debug!("Middleware `{name}` returned an error, exit middleware loop"); + return Err(e.into_response(protocol)); + } } } + tracing::debug!("Returning original request to operation handler"); Ok(request) }) } diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs index d51ba646f0..7eb3e795c6 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs @@ -10,7 +10,7 @@ use pyo3_asyncio::TaskLocals; pub use self::handler::{PyMiddlewareHandler, PyMiddlewareHandlers}; pub use self::layer::PyMiddlewareLayer; -pub use self::request::PyRequest; +pub use self::request::{PyRequest, PyHttpVersion}; pub use self::response::PyResponse; pub trait PyMiddlewareTrait { diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs index 34e49afba2..c4f6e89a29 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs @@ -1,8 +1,19 @@ use std::collections::HashMap; -use http::Request; +use http::{Request, Version}; use pyo3::prelude::*; +#[pyclass(name = "HttpVersion")] +#[derive(PartialEq, PartialOrd, Copy, Clone, Eq, Ord, Hash)] +pub enum PyHttpVersion { + Http09, + Http10, + Http11, + H2, + H3, + __NonExhaustive +} + #[pyclass(name = "Request")] #[derive(Debug, Clone)] pub struct PyRequest { @@ -10,8 +21,8 @@ pub struct PyRequest { method: String, #[pyo3(get, set)] uri: String, - #[pyo3(get, set)] pub headers: HashMap, + version: Version, } impl PyRequest { @@ -28,6 +39,7 @@ impl PyRequest { (name, value) }) .collect(), + version: request.version(), } } } @@ -35,11 +47,50 @@ impl PyRequest { #[pymethods] impl PyRequest { #[new] - fn newpy(method: String, uri: String, headers: Option>) -> Self { + fn newpy( + method: String, + uri: String, + headers: Option>, + version: Option, + ) -> Self { + let version = version + .map(|v| match v { + PyHttpVersion::Http09 => Version::HTTP_09, + PyHttpVersion::Http10 => Version::HTTP_10, + PyHttpVersion::Http11 => Version::HTTP_11, + PyHttpVersion::H2 => Version::HTTP_2, + PyHttpVersion::H3 => Version::HTTP_3, + _ => unreachable!(), + }) + .unwrap_or(Version::HTTP_11); Self { method, uri, headers: headers.unwrap_or_default(), + version, + } + } + + fn version(&self) -> PyHttpVersion { + match self.version { + Version::HTTP_09 => PyHttpVersion::Http09, + Version::HTTP_10 => PyHttpVersion::Http10, + Version::HTTP_11 => PyHttpVersion::Http11, + Version::HTTP_2 => PyHttpVersion::H2, + Version::HTTP_3 => PyHttpVersion::H3, + _ => unreachable!(), } } + + fn headers(&self) -> HashMap { + self.headers.clone() + } + + fn set_header(&mut self, key: &str, value: &str) { + self.headers.insert(key.to_string(), value.to_string()); + } + + fn get_header(&self, key: &str) -> Option<&String> { + self.headers.get(key) + } } diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs index 5fe009cf30..51c412e64f 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs @@ -11,7 +11,6 @@ pub struct PyResponse { status: u16, #[pyo3(get, set)] body: Vec, - #[pyo3(get, set)] headers: HashMap, } @@ -25,17 +24,30 @@ impl PyResponse { headers: headers.unwrap_or_default(), } } + + fn headers(&self) -> HashMap { + self.headers.clone() + } + + fn set_header(&mut self, key: &str, value: &str) { + self.headers.insert(key.to_string(), value.to_string()); + } + + fn get_header(&self, key: &str) -> Option<&String> { + self.headers.get(key) + } } impl From for Response { - fn from(val: PyResponse) -> Self { + fn from(pyresponse: PyResponse) -> Self { let mut response = Response::builder() - .status(StatusCode::from_u16(val.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)) - .body(to_boxed(val.body)) + .status(StatusCode::from_u16(pyresponse.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)) + .body(to_boxed(pyresponse.body)) .unwrap_or_default(); - if let Ok(headers) = (&val.headers).try_into() { - *response.headers_mut() = headers; - } + match (&pyresponse.headers).try_into() { + Ok(headers) => *response.headers_mut() = headers, + Err(e) => tracing::error!("Error extracting HTTP headers from PyResponse: {e}") + }; response } } From aa2dc0a3118e869a1691ba83a16e2601779f9899 Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Fri, 16 Sep 2022 12:21:17 -0700 Subject: [PATCH 07/30] Remove unused errors --- .../customizations/PythonServerCodegenDecorator.kt | 1 - rust-runtime/aws-smithy-http-server-python/src/error.rs | 9 --------- 2 files changed, 10 deletions(-) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/customizations/PythonServerCodegenDecorator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/customizations/PythonServerCodegenDecorator.kt index 233b477022..9ab914f01b 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/customizations/PythonServerCodegenDecorator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/customizations/PythonServerCodegenDecorator.kt @@ -22,7 +22,6 @@ import software.amazon.smithy.rust.codegen.core.util.toSnakeCase import software.amazon.smithy.rust.codegen.server.python.smithy.PythonServerRuntimeType import software.amazon.smithy.rust.codegen.server.python.smithy.generators.PythonServerModuleGenerator import software.amazon.smithy.rust.codegen.server.smithy.customizations.AddInternalServerErrorToAllOperationsDecorator -import software.amazon.smithy.rust.codegen.server.smithy.customizations.AddMiddlewareErrorToAllOperationsDecorator /** * Configure the [lib] section of `Cargo.toml`. diff --git a/rust-runtime/aws-smithy-http-server-python/src/error.rs b/rust-runtime/aws-smithy-http-server-python/src/error.rs index 417b3c9d27..9015449439 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/error.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/error.rs @@ -8,7 +8,6 @@ use aws_smithy_http_server::protocols::Protocol; use aws_smithy_http_server::{body::to_boxed, response::Response}; use aws_smithy_types::date_time::{ConversionError, DateTimeParseError}; -use http::{header::ToStrError, status::InvalidStatusCode, Error as HttpError}; use pyo3::{create_exception, exceptions::PyException as BasePyException, prelude::*, PyErr}; use thiserror::Error; @@ -21,14 +20,6 @@ pub enum PyError { /// Implements `From`. #[error("DateTimeParse: {0}")] DateTimeParse(#[from] DateTimeParseError), - /// Http errors - #[error("HTTP error: {0}")] - Http(#[from] HttpError), - /// Status code error - #[error("{0}")] - HttpStatusCode(#[from] InvalidStatusCode), - #[error("{0}")] - StrConversion(#[from] ToStrError), } create_exception!(smithy, PyException, BasePyException); From 5e784f692dddaf8f0f1e010cf65df70c3ff6766a Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Fri, 16 Sep 2022 13:36:00 -0700 Subject: [PATCH 08/30] Refactor --- .../generators/PythonApplicationGenerator.kt | 2 +- .../src/middleware/layer.rs | 6 +-- .../src/middleware/mod.rs | 14 ------- .../src/server.rs | 38 ++++++++++--------- 4 files changed, 25 insertions(+), 35 deletions(-) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index edd73bf44f..9a5eea3ac2 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -281,7 +281,7 @@ class PythonApplicationGenerator( ##[pyo3(text_signature = "(${'$'}self, func)")] pub fn $name(&mut self, py: #{pyo3}::Python, func: #{pyo3}::PyObject) -> #{pyo3}::PyResult<()> { use #{SmithyPython}::PyApp; - self.register_operation(py, "$name", func) + self.register_operation(py, func) } """, *codegenScope, diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs index 75752bbeaf..dcda64bf3d 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs @@ -88,7 +88,7 @@ where let run = self.handler.run(req, self.protocol, self.locals.clone()); ResponseFuture { - middleware: State::Run { run }, + middleware: State::Running { run }, service: inner, } } @@ -109,7 +109,7 @@ pin_project! { pin_project! { #[project = StateProj] enum State { - Run { + Running { #[pin] run: A, }, @@ -131,7 +131,7 @@ where let mut this = self.project(); loop { match this.middleware.as_mut().project() { - StateProj::Run { run } => { + StateProj::Running { run } => { let run = ready!(run.poll(cx)); match run { Ok(req) => { diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs index 7eb3e795c6..24fc38dadb 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs @@ -20,17 +20,3 @@ pub trait PyMiddlewareTrait { fn run(&mut self, request: Request, protocol: Protocol, locals: TaskLocals) -> Self::Future; } - -impl PyMiddlewareTrait for F -where - F: FnMut(Request) -> Fut, - Fut: Future, Response>>, -{ - type RequestBody = ReqBody; - type ResponseBody = ResBody; - type Future = Fut; - - fn run(&mut self, request: Request, _protocol: Protocol, _locals: TaskLocals) -> Self::Future { - self(request) - } -} diff --git a/rust-runtime/aws-smithy-http-server-python/src/server.rs b/rust-runtime/aws-smithy-http-server-python/src/server.rs index 4b6c0e0089..17d63ea227 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/server.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/server.rs @@ -13,7 +13,7 @@ use signal_hook::{consts::*, iterator::Signals}; use tokio::runtime; use tower::ServiceBuilder; -use crate::{PySocket, middleware::PyMiddlewareHandler, PyMiddlewareHandlers}; +use crate::{middleware::PyMiddlewareHandler, PyMiddlewareHandlers, PySocket}; /// A Python handler function representation. /// @@ -233,9 +233,8 @@ event_loop.add_signal_handler(signal.SIGINT, // all inside a [tokio] blocking function. rt.block_on(async move { tracing::debug!("Add middlewares to Rust Python router"); - let service = ServiceBuilder::new() - .layer(AddExtensionLayer::new(context)); - let app = router.layer(service); + let app = + router.layer(ServiceBuilder::new().layer(AddExtensionLayer::new(context))); let server = hyper::Server::from_tcp( raw_socket .try_into() @@ -257,15 +256,23 @@ event_loop.add_signal_handler(signal.SIGINT, Ok(()) } - fn register_middleware(&mut self, py: Python, func: PyObject) -> PyResult<()> { + fn is_coroutine(&self, py: Python, func: &PyObject) -> PyResult { let inspect = py.import("inspect")?; // Check if the function is a coroutine. // NOTE: that `asyncio.iscoroutine()` doesn't work here. - let is_coroutine = inspect - .call_method1("iscoroutinefunction", (&func,))? - .extract::()?; + inspect + .call_method1("iscoroutinefunction", (func,))? + .extract::() + } + + /// Register a Python function to be executed inside a Tower middleware layer. + /// + /// There are some information needed to execute the Python code from a Rust handler, + /// such has if the registered function needs to be awaited (if it is a coroutine).. + fn register_middleware(&mut self, py: Python, func: PyObject) -> PyResult<()> { let name = func.getattr(py, "__name__")?.extract::(py)?; - // Find number of expected methods (a Pythzzon implementation could not accept the context). + let is_coroutine = self.is_coroutine(py, &func)?; + // Find number of expected methods (a Python implementation could not accept the context). let handler = PyMiddlewareHandler { name, func, @@ -286,14 +293,11 @@ event_loop.add_signal_handler(signal.SIGINT, /// such has if the registered function needs to be awaited (if it is a coroutine) and /// the number of arguments available, which tells us if the handler wants the state to be /// passed or not. - fn register_operation(&mut self, py: Python, name: &str, func: PyObject) -> PyResult<()> { - let inspect = py.import("inspect")?; - // Check if the function is a coroutine. - // NOTE: that `asyncio.iscoroutine()` doesn't work here. - let is_coroutine = inspect - .call_method1("iscoroutinefunction", (&func,))? - .extract::()?; + fn register_operation(&mut self, py: Python, func: PyObject) -> PyResult<()> { + let name = func.getattr(py, "__name__")?.extract::(py)?; + let is_coroutine = self.is_coroutine(py, &func)?; // Find number of expected methods (a Pythzzon implementation could not accept the context). + let inspect = py.import("inspect")?; let func_args = inspect .call_method1("getargs", (func.getattr(py, "__code__")?,))? .getattr("args")? @@ -309,7 +313,7 @@ event_loop.add_signal_handler(signal.SIGINT, handler.args, ); // Insert the handler in the handlers map. - self.handlers().insert(String::from(name), handler); + self.handlers().insert(name, handler); Ok(()) } From 03c660b02bc4d2a1b39ee7740492ba14bc6045c8 Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Fri, 16 Sep 2022 15:05:42 -0700 Subject: [PATCH 09/30] Remove trait --- .../generators/PythonApplicationGenerator.kt | 4 +- .../aws-smithy-http-server-python/src/lib.rs | 2 +- .../src/middleware/handler.rs | 36 ++++----- .../src/middleware/layer.rs | 78 ++++++++++--------- .../src/middleware/mod.rs | 12 +-- .../src/middleware/request.rs | 5 +- .../src/server.rs | 4 +- 7 files changed, 68 insertions(+), 73 deletions(-) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index 9a5eea3ac2..d988564a44 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -104,7 +104,7 @@ class PythonApplicationGenerator( ##[derive(Debug, Default)] pub struct App { handlers: #{HashMap}, - middlewares: #{SmithyPython}::PyMiddlewareHandlers, + middlewares: #{SmithyPython}::PyMiddlewares, context: Option<#{pyo3}::PyObject>, workers: #{parking_lot}::Mutex>, } @@ -202,7 +202,7 @@ class PythonApplicationGenerator( fn handlers(&mut self) -> &mut #{HashMap} { &mut self.handlers } - fn middlewares(&mut self) -> &mut #{SmithyPython}::PyMiddlewareHandlers { + fn middlewares(&mut self) -> &mut #{SmithyPython}::PyMiddlewares { &mut self.middlewares } fn protocol(&self) -> &'static str { diff --git a/rust-runtime/aws-smithy-http-server-python/src/lib.rs b/rust-runtime/aws-smithy-http-server-python/src/lib.rs index e87564ba60..b455f03d7a 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/lib.rs @@ -23,7 +23,7 @@ pub use error::{PyError, PyMiddlewareException}; #[doc(inline)] pub use logging::LogLevel; #[doc(inline)] -pub use middleware::{PyMiddlewareHandlers, PyMiddlewareLayer, PyRequest, PyResponse, PyHttpVersion}; +pub use middleware::{PyMiddlewares, PyMiddlewareLayer, PyRequest, PyResponse, PyHttpVersion}; #[doc(inline)] pub use server::{PyApp, PyHandler}; #[doc(inline)] diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs index 2f7c7b5b0d..5702f9dede 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs @@ -1,6 +1,5 @@ -use aws_smithy_http_server::body::BoxBody; -use futures::future::BoxFuture; -use http::{Request, Response}; +use aws_smithy_http_server::body::Body; +use http::Request; use pyo3::prelude::*; use aws_smithy_http_server::protocols::Protocol; @@ -8,7 +7,7 @@ use pyo3_asyncio::TaskLocals; use crate::{PyMiddlewareException, PyRequest, PyResponse}; -use super::PyMiddlewareTrait; +use super::PyFuture; #[derive(Debug, Clone)] pub struct PyMiddlewareHandler { @@ -18,9 +17,9 @@ pub struct PyMiddlewareHandler { } #[derive(Debug, Clone, Default)] -pub struct PyMiddlewareHandlers(Vec); +pub struct PyMiddlewares(Vec); -impl PyMiddlewareHandlers { +impl PyMiddlewares { pub fn new(handlers: Vec) -> Self { Self(handlers) } @@ -77,25 +76,16 @@ impl PyMiddlewareHandlers { }), }) } -} - -impl PyMiddlewareTrait for PyMiddlewareHandlers -where - B: Send + Sync + 'static, -{ - type RequestBody = B; - type ResponseBody = BoxBody; - type Future = BoxFuture<'static, Result, Response>>; - fn run( + pub fn run( &mut self, - mut request: http::Request, + mut request: Request, protocol: Protocol, locals: TaskLocals, - ) -> Self::Future { + ) -> PyFuture { let handlers = self.0.clone(); + // Run all Python handlers in a loop. Box::pin(async move { - // Run all Python handlers in a loop. tracing::debug!("Executing Python middleware stack"); for handler in handlers { let name = handler.name.clone(); @@ -115,12 +105,16 @@ where } } if let Some(pyresponse) = pyresponse { - tracing::debug!("Middleware `{name}` returned a HTTP response, exit middleware loop"); + tracing::debug!( + "Middleware `{name}` returned a HTTP response, exit middleware loop" + ); return Err(pyresponse.into()); } } Err(e) => { - tracing::debug!("Middleware `{name}` returned an error, exit middleware loop"); + tracing::debug!( + "Middleware `{name}` returned an error, exit middleware loop" + ); return Err(e.into_response(protocol)); } } diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs index dcda64bf3d..cba5a739e8 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs @@ -3,24 +3,24 @@ use std::{ task::{Context, Poll}, }; -use aws_smithy_http_server::protocols::Protocol; +use aws_smithy_http_server::{body::{Body, BoxBody}, protocols::Protocol}; use futures::{ready, Future}; use http::{Request, Response}; use pin_project_lite::pin_project; use pyo3_asyncio::TaskLocals; use tower::{Layer, Service}; -use super::PyMiddlewareTrait; +use crate::{PyMiddlewares, middleware::PyFuture}; #[derive(Debug, Clone)] -pub struct PyMiddlewareLayer { - handler: T, +pub struct PyMiddlewareLayer { + handlers: PyMiddlewares, protocol: Protocol, locals: TaskLocals, } -impl PyMiddlewareLayer { - pub fn new(handler: T, protocol: &str, locals: TaskLocals) -> PyMiddlewareLayer { +impl PyMiddlewareLayer { + pub fn new(handlers: PyMiddlewares, protocol: &str, locals: TaskLocals) -> PyMiddlewareLayer { let protocol = match protocol { "aws.protocols#restJson1" => Protocol::RestJson1, "aws.protocols#restXml" => Protocol::RestXml, @@ -29,63 +29,69 @@ impl PyMiddlewareLayer { _ => panic!(), }; Self { - handler, + handlers, protocol, locals, } } } -impl Layer for PyMiddlewareLayer -where - T: Clone, -{ - type Service = PyMiddlewareService; +impl Layer for PyMiddlewareLayer { + type Service = PyMiddlewareService; fn layer(&self, inner: S) -> Self::Service { - PyMiddlewareService::new(inner, self.handler.clone(), self.protocol, self.locals.clone()) + PyMiddlewareService::new( + inner, + self.handlers.clone(), + self.protocol, + self.locals.clone(), + ) } } #[derive(Clone, Debug)] -pub struct PyMiddlewareService { +pub struct PyMiddlewareService { inner: S, - handler: T, + handlers: PyMiddlewares, protocol: Protocol, - locals: TaskLocals + locals: TaskLocals, } -impl PyMiddlewareService { - pub fn new(inner: S, handler: T, protocol: Protocol, locals: TaskLocals) -> PyMiddlewareService { +impl PyMiddlewareService { + pub fn new( + inner: S, + handlers: PyMiddlewares, + protocol: Protocol, + locals: TaskLocals, + ) -> PyMiddlewareService { Self { inner, - handler, + handlers, protocol, - locals + locals, } } - pub fn layer(handler: T, protocol: &str, locals: TaskLocals) -> PyMiddlewareLayer { - PyMiddlewareLayer::new(handler, protocol, locals) + pub fn layer(handlers: PyMiddlewares, protocol: &str, locals: TaskLocals) -> PyMiddlewareLayer { + PyMiddlewareLayer::new(handlers, protocol, locals) } } -impl Service> for PyMiddlewareService +impl Service> for PyMiddlewareService where - M: PyMiddlewareTrait, - S: Service, Response = Response> + Clone, + S: Service, Response = Response> + Clone, { - type Response = Response; + type Response = Response; type Error = S::Error; - type Future = ResponseFuture; + type Future = ResponseFuture; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, req: Request) -> Self::Future { let inner = self.inner.clone(); - let run = self.handler.run(req, self.protocol, self.locals.clone()); + let run = self.handlers.run(req, self.protocol, self.locals.clone()); ResponseFuture { middleware: State::Running { run }, @@ -95,13 +101,12 @@ where } pin_project! { - pub struct ResponseFuture + pub struct ResponseFuture where - M: PyMiddlewareTrait, - S: Service>, + S: Service>, { #[pin] - middleware: State, + middleware: State, service: S, } } @@ -120,12 +125,11 @@ pin_project! { } } -impl Future for ResponseFuture +impl Future for ResponseFuture where - M: PyMiddlewareTrait, - S: Service, Response = Response>, + S: Service, Response = Response>, { - type Output = Result, S::Error>; + type Output = Result, S::Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let mut this = self.project(); diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs index 24fc38dadb..a887d0703f 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs @@ -3,20 +3,16 @@ mod layer; mod request; mod response; +use aws_smithy_http_server::body::{Body, BoxBody}; use aws_smithy_http_server::protocols::Protocol; use futures::Future; +use futures::future::BoxFuture; use http::{Request, Response}; use pyo3_asyncio::TaskLocals; -pub use self::handler::{PyMiddlewareHandler, PyMiddlewareHandlers}; +pub use self::handler::{PyMiddlewareHandler, PyMiddlewares}; pub use self::layer::PyMiddlewareLayer; pub use self::request::{PyRequest, PyHttpVersion}; pub use self::response::PyResponse; -pub trait PyMiddlewareTrait { - type RequestBody; - type ResponseBody; - type Future: Future, Response>>; - - fn run(&mut self, request: Request, protocol: Protocol, locals: TaskLocals) -> Self::Future; -} +pub type PyFuture = BoxFuture<'static, Result, Response>>; diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs index c4f6e89a29..cc4e011f67 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use aws_smithy_http_server::body::Body; use http::{Request, Version}; use pyo3::prelude::*; @@ -11,7 +12,7 @@ pub enum PyHttpVersion { Http11, H2, H3, - __NonExhaustive + __NonExhaustive, } #[pyclass(name = "Request")] @@ -26,7 +27,7 @@ pub struct PyRequest { } impl PyRequest { - pub fn new(request: &Request) -> Self { + pub fn new(request: &Request) -> Self { Self { method: request.method().to_string(), uri: request.uri().to_string(), diff --git a/rust-runtime/aws-smithy-http-server-python/src/server.rs b/rust-runtime/aws-smithy-http-server-python/src/server.rs index 17d63ea227..52b4459ae4 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/server.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/server.rs @@ -13,7 +13,7 @@ use signal_hook::{consts::*, iterator::Signals}; use tokio::runtime; use tower::ServiceBuilder; -use crate::{middleware::PyMiddlewareHandler, PyMiddlewareHandlers, PySocket}; +use crate::{middleware::PyMiddlewareHandler, PyMiddlewares, PySocket}; /// A Python handler function representation. /// @@ -61,7 +61,7 @@ pub trait PyApp: Clone + pyo3::IntoPy { /// Mapping between operation names and their `PyHandler` representation. fn handlers(&mut self) -> &mut HashMap; - fn middlewares(&mut self) -> &mut PyMiddlewareHandlers; + fn middlewares(&mut self) -> &mut PyMiddlewares; fn protocol(&self) -> &'static str; From 28e3c09fac515708cc666fa19a814537a49758ca Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Fri, 16 Sep 2022 19:34:49 -0700 Subject: [PATCH 10/30] Add testing of middleware handlers --- .../generators/PythonApplicationGenerator.kt | 2 +- .../examples/pokemon_service.py | 12 +- .../aws-smithy-http-server-python/src/lib.rs | 17 +- .../src/middleware/handler.rs | 188 ++++++++++++++++++ .../src/middleware/layer.rs | 12 +- .../src/middleware/mod.rs | 10 +- .../src/middleware/request.rs | 5 + .../src/middleware/response.rs | 12 +- .../src/server.rs | 11 +- 9 files changed, 245 insertions(+), 24 deletions(-) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index d988564a44..1d3934e028 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -281,7 +281,7 @@ class PythonApplicationGenerator( ##[pyo3(text_signature = "(${'$'}self, func)")] pub fn $name(&mut self, py: #{pyo3}::Python, func: #{pyo3}::PyObject) -> #{pyo3}::PyResult<()> { use #{SmithyPython}::PyApp; - self.register_operation(py, func) + self.register_operation(py, "$name", func) } """, *codegenScope, diff --git a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py index 86ce9a0125..fa8a6c0a66 100644 --- a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py +++ b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py @@ -139,11 +139,7 @@ def check_content_type_header(request: Request): if content_type == "application/json": logging.debug("Found valid `application/json` content type") else: - logging.error(f"Invalid content type: {content_type}") - # Return an HTTP 401 Unauthorized if the content type is not JSON. - raise MiddlewareException("Invalid content type", 401) - # Check that `x-amzn-answer` header is not set. - assert request.get_header("x-amzn-answer") is None + logging.warning(f"Invalid content type: {content_type}") # This middleware adds a new header called `x-amazon-answer` to the @@ -164,9 +160,11 @@ async def check_method_and_content_length(request: Request): content_length = int(content_length) logging.debug("Request content length: {content_length}") else: - logging.error(f"Invalid content length. Dumping headers: {request.headers()}") + logging.warning(f"Invalid content length. Dumping headers: {request.headers()}") # Check that `x-amzn-answer` is 42. - assert request.get_header("x-amzn-answer") == "42" + if request.get_header("x-amzn-answer") != "42": + # Return an HTTP 401 Unauthorized if the content type is not JSON. + raise MiddlewareException("Invalid answer", 401) ########################################################### diff --git a/rust-runtime/aws-smithy-http-server-python/src/lib.rs b/rust-runtime/aws-smithy-http-server-python/src/lib.rs index b455f03d7a..687769774a 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/lib.rs @@ -13,7 +13,7 @@ mod error; pub mod logging; -mod middleware; +pub mod middleware; mod server; mod socket; pub mod types; @@ -23,7 +23,7 @@ pub use error::{PyError, PyMiddlewareException}; #[doc(inline)] pub use logging::LogLevel; #[doc(inline)] -pub use middleware::{PyMiddlewares, PyMiddlewareLayer, PyRequest, PyResponse, PyHttpVersion}; +pub use middleware::{PyHttpVersion, PyMiddlewareLayer, PyMiddlewares, PyRequest, PyResponse}; #[doc(inline)] pub use server::{PyApp, PyHandler}; #[doc(inline)] @@ -33,11 +33,22 @@ pub use socket::PySocket; mod tests { use std::sync::Once; + use pyo3::{PyErr, Python}; + use pyo3_asyncio::TaskLocals; + static INIT: Once = Once::new(); - pub(crate) fn initialize() { + pub(crate) fn initialize() -> TaskLocals { INIT.call_once(|| { pyo3::prepare_freethreaded_python(); }); + + Python::with_gil(|py| { + let asyncio = py.import("asyncio")?; + let event_loop = asyncio.call_method0("new_event_loop")?; + asyncio.call_method1("set_event_loop", (event_loop,))?; + Ok::(TaskLocals::new(event_loop)) + }) + .unwrap() } } diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs index 5702f9dede..7c814f90f2 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs @@ -1,3 +1,9 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +//! Execute Python middleware handlers. use aws_smithy_http_server::body::Body; use http::Request; use pyo3::prelude::*; @@ -124,3 +130,185 @@ impl PyMiddlewares { }) } } + +#[cfg(test)] +mod tests { + use http::HeaderValue; + use hyper::body::to_bytes; + use pretty_assertions::assert_eq; + + use super::*; + + #[tokio::test] + async fn middleware_chain_keeps_headers_changes() -> PyResult<()> { + let locals = crate::tests::initialize(); + let mut middlewares = PyMiddlewares(vec![]); + + Python::with_gil(|py| { + let middleware = PyModule::new(py, "middleware").unwrap(); + middleware.add_class::().unwrap(); + middleware.add_class::().unwrap(); + let pycode = r#" +def first_middleware(request: Request): + request.set_header("x-amzn-answer", "42") + return request + +async def second_middleware(request: Request): + if request.get_header("x-amzn-answer") != "42": + raise MiddlewareException("wrong answer", 401) +"#; + py.run(pycode, Some(middleware.dict()), None)?; + let all = middleware.index()?; + let first_middleware = PyMiddlewareHandler { + func: middleware.getattr("first_middleware")?.into_py(py), + is_coroutine: false, + name: "first".to_string(), + }; + all.append("first_middleware")?; + middlewares.push(first_middleware); + let second_middleware = PyMiddlewareHandler { + func: middleware.getattr("second_middleware")?.into_py(py), + is_coroutine: false, + name: "second".to_string(), + }; + all.append("second_middleware")?; + middlewares.push(second_middleware); + Ok::<(), PyErr>(()) + })?; + + let result = middlewares + .run( + Request::builder().body(Body::from("")).unwrap(), + Protocol::RestJson1, + locals, + ) + .await + .unwrap(); + assert_eq!( + result.headers().get("x-amzn-answer"), + Some(&HeaderValue::from_static("42")) + ); + Ok(()) + } + + #[tokio::test] + async fn middleware_return_response() -> PyResult<()> { + let locals = crate::tests::initialize(); + let mut middlewares = PyMiddlewares(vec![]); + + Python::with_gil(|py| { + let middleware = PyModule::new(py, "middleware").unwrap(); + middleware.add_class::().unwrap(); + middleware.add_class::().unwrap(); + let pycode = r#" +def middleware(request: Request): + return Response(200, {}, b"something")"#; + py.run(pycode, Some(middleware.dict()), None)?; + let all = middleware.index()?; + let middleware = PyMiddlewareHandler { + func: middleware.getattr("middleware")?.into_py(py), + is_coroutine: false, + name: "middleware".to_string(), + }; + all.append("middleware")?; + middlewares.push(middleware); + Ok::<(), PyErr>(()) + })?; + + let result = middlewares + .run( + Request::builder().body(Body::from("")).unwrap(), + Protocol::RestJson1, + locals, + ) + .await + .unwrap_err(); + assert_eq!(result.status(), 200); + let body = to_bytes(result.into_body()).await.unwrap(); + assert_eq!(body, "something".as_bytes()); + Ok(()) + } + + #[tokio::test] + async fn middleware_raise_middleware_exception() -> PyResult<()> { + let locals = crate::tests::initialize(); + let mut middlewares = PyMiddlewares(vec![]); + + Python::with_gil(|py| { + let middleware = PyModule::new(py, "middleware").unwrap(); + middleware.add_class::().unwrap(); + middleware.add_class::().unwrap(); + let pycode = r#" +def middleware(request: Request): + raise MiddlewareException("error", 503)"#; + py.run(pycode, Some(middleware.dict()), None)?; + let all = middleware.index()?; + let middleware = PyMiddlewareHandler { + func: middleware.getattr("middleware")?.into_py(py), + is_coroutine: false, + name: "middleware".to_string(), + }; + all.append("middleware")?; + middlewares.push(middleware); + Ok::<(), PyErr>(()) + })?; + + let result = middlewares + .run( + Request::builder().body(Body::from("")).unwrap(), + Protocol::RestJson1, + locals, + ) + .await + .unwrap_err(); + assert_eq!(result.status(), 503); + assert_eq!( + result.headers().get("X-Amzn-Errortype"), + Some(&HeaderValue::from_static("MiddlewareException")) + ); + let body = to_bytes(result.into_body()).await.unwrap(); + assert_eq!(body, r#"{"message":"error"}"#.as_bytes()); + Ok(()) + } + + #[tokio::test] + async fn middleware_raise_python_exception() -> PyResult<()> { + let locals = crate::tests::initialize(); + let mut middlewares = PyMiddlewares(vec![]); + + Python::with_gil(|py| { + let middleware = PyModule::from_code( + py, + r#" +def middleware(request): + raise ValueError("error")"#, + "", + "", + )?; + let middleware = PyMiddlewareHandler { + func: middleware.getattr("middleware")?.into_py(py), + is_coroutine: false, + name: "middleware".to_string(), + }; + middlewares.push(middleware); + Ok::<(), PyErr>(()) + })?; + + let result = middlewares + .run( + Request::builder().body(Body::from("")).unwrap(), + Protocol::RestJson1, + locals, + ) + .await + .unwrap_err(); + assert_eq!(result.status(), 500); + assert_eq!( + result.headers().get("X-Amzn-Errortype"), + Some(&HeaderValue::from_static("MiddlewareException")) + ); + let body = to_bytes(result.into_body()).await.unwrap(); + assert_eq!(body, r#"{"message":"ValueError: error"}"#.as_bytes()); + Ok(()) + } +} diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs index cba5a739e8..98b64f9792 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs @@ -1,16 +1,24 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + use std::{ pin::Pin, task::{Context, Poll}, }; -use aws_smithy_http_server::{body::{Body, BoxBody}, protocols::Protocol}; +use aws_smithy_http_server::{ + body::{Body, BoxBody}, + protocols::Protocol, +}; use futures::{ready, Future}; use http::{Request, Response}; use pin_project_lite::pin_project; use pyo3_asyncio::TaskLocals; use tower::{Layer, Service}; -use crate::{PyMiddlewares, middleware::PyFuture}; +use crate::{middleware::PyFuture, PyMiddlewares}; #[derive(Debug, Clone)] pub struct PyMiddlewareLayer { diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs index a887d0703f..92b9eb8b36 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs @@ -1,18 +1,20 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + mod handler; mod layer; mod request; mod response; use aws_smithy_http_server::body::{Body, BoxBody}; -use aws_smithy_http_server::protocols::Protocol; -use futures::Future; use futures::future::BoxFuture; use http::{Request, Response}; -use pyo3_asyncio::TaskLocals; pub use self::handler::{PyMiddlewareHandler, PyMiddlewares}; pub use self::layer::PyMiddlewareLayer; -pub use self::request::{PyRequest, PyHttpVersion}; +pub use self::request::{PyHttpVersion, PyRequest}; pub use self::response::PyResponse; pub type PyFuture = BoxFuture<'static, Result, Response>>; diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs index cc4e011f67..f3fcf83bcd 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs @@ -1,3 +1,8 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + use std::collections::HashMap; use aws_smithy_http_server::body::Body; diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs index 51c412e64f..786aa949e9 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs @@ -1,3 +1,8 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + use std::{collections::HashMap, convert::TryInto}; use aws_smithy_http_server::body::{to_boxed, BoxBody}; @@ -41,12 +46,15 @@ impl PyResponse { impl From for Response { fn from(pyresponse: PyResponse) -> Self { let mut response = Response::builder() - .status(StatusCode::from_u16(pyresponse.status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)) + .status( + StatusCode::from_u16(pyresponse.status) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), + ) .body(to_boxed(pyresponse.body)) .unwrap_or_default(); match (&pyresponse.headers).try_into() { Ok(headers) => *response.headers_mut() = headers, - Err(e) => tracing::error!("Error extracting HTTP headers from PyResponse: {e}") + Err(e) => tracing::error!("Error extracting HTTP headers from PyResponse: {e}"), }; response } diff --git a/rust-runtime/aws-smithy-http-server-python/src/server.rs b/rust-runtime/aws-smithy-http-server-python/src/server.rs index 52b4459ae4..1125ef0b37 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/server.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/server.rs @@ -293,10 +293,9 @@ event_loop.add_signal_handler(signal.SIGINT, /// such has if the registered function needs to be awaited (if it is a coroutine) and /// the number of arguments available, which tells us if the handler wants the state to be /// passed or not. - fn register_operation(&mut self, py: Python, func: PyObject) -> PyResult<()> { - let name = func.getattr(py, "__name__")?.extract::(py)?; + fn register_operation(&mut self, py: Python, name: &str, func: PyObject) -> PyResult<()> { let is_coroutine = self.is_coroutine(py, &func)?; - // Find number of expected methods (a Pythzzon implementation could not accept the context). + // Find number of expected methods (a Python implementation could not accept the context). let inspect = py.import("inspect")?; let func_args = inspect .call_method1("getargs", (func.getattr(py, "__code__")?,))? @@ -313,7 +312,7 @@ event_loop.add_signal_handler(signal.SIGINT, handler.args, ); // Insert the handler in the handlers map. - self.handlers().insert(name, handler); + self.handlers().insert(name.to_string(), handler); Ok(()) } @@ -358,7 +357,7 @@ event_loop.add_signal_handler(signal.SIGINT, /// ```no_run /// use std::collections::HashMap; /// use pyo3::prelude::*; - /// use aws_smithy_http_server_python::{PyApp, PyHandler}; + /// use aws_smithy_http_server_python::{PyApp, PyHandler, PyMiddlewares}; /// use parking_lot::Mutex; /// /// #[pyclass] @@ -373,6 +372,8 @@ event_loop.add_signal_handler(signal.SIGINT, /// fn workers(&self) -> &Mutex> { todo!() } /// fn context(&self) -> &Option { todo!() } /// fn handlers(&mut self) -> &mut HashMap { todo!() } + /// fn middlewares(&mut self) -> &mut PyMiddlewares { todo!() } + /// fn protocol(&self) -> &'static str { "proto1" } /// } /// /// #[pymethods] From f3e21de1fe5b474cffda18321f56c6f9106cc97b Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Fri, 16 Sep 2022 20:17:09 -0700 Subject: [PATCH 11/30] Add end to end test of the service --- .../generators/PythonApplicationGenerator.kt | 3 +- .../generators/PythonServerModuleGenerator.kt | 2 +- .../aws-smithy-http-server-python/Cargo.toml | 2 +- .../src/middleware/handler.rs | 2 +- .../src/middleware/layer.rs | 75 ++++++++++++++++++- 5 files changed, 75 insertions(+), 9 deletions(-) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index 1d3934e028..0248aab7c0 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -7,18 +7,17 @@ package software.amazon.smithy.rust.codegen.server.python.smithy.generators import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.traits.DocumentationTrait +import software.amazon.smithy.rust.codegen.client.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.client.rustlang.RustType import software.amazon.smithy.rust.codegen.client.rustlang.RustWriter import software.amazon.smithy.rust.codegen.client.rustlang.asType import software.amazon.smithy.rust.codegen.client.rustlang.rust import software.amazon.smithy.rust.codegen.client.rustlang.rustBlockTemplate import software.amazon.smithy.rust.codegen.client.rustlang.rustTemplate -import software.amazon.smithy.rust.codegen.client.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.client.smithy.CoreCodegenContext import software.amazon.smithy.rust.codegen.client.smithy.Errors import software.amazon.smithy.rust.codegen.client.smithy.Inputs import software.amazon.smithy.rust.codegen.client.smithy.Outputs -import software.amazon.smithy.rust.codegen.client.rustlang.escape import software.amazon.smithy.rust.codegen.core.util.getTrait import software.amazon.smithy.rust.codegen.core.util.inputShape import software.amazon.smithy.rust.codegen.core.util.outputShape diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt index ad639aa5b6..76a0b90314 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt @@ -141,7 +141,7 @@ class PythonServerModuleGenerator( ); m.add_submodule(middleware)?; """, - *codegenScope + *codegenScope, ) } diff --git a/rust-runtime/aws-smithy-http-server-python/Cargo.toml b/rust-runtime/aws-smithy-http-server-python/Cargo.toml index 2be120e75e..cde29acc9a 100644 --- a/rust-runtime/aws-smithy-http-server-python/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server-python/Cargo.toml @@ -33,7 +33,7 @@ socket2 = { version = "0.4.4", features = ["all"] } thiserror = "1.0.32" tokio = { version = "1.20.1", features = ["full"] } tokio-stream = "0.1" -tower = "0.4.13" +tower = { version = "0.4.13", features = ["util"] } tracing = "0.1.36" tracing-subscriber = { version = "0.3.15", features = ["env-filter"] } diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs index 7c814f90f2..e1a3645008 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs @@ -23,7 +23,7 @@ pub struct PyMiddlewareHandler { } #[derive(Debug, Clone, Default)] -pub struct PyMiddlewares(Vec); +pub struct PyMiddlewares(pub Vec); impl PyMiddlewares { pub fn new(handlers: Vec) -> Self { diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs index 98b64f9792..149167a18c 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs @@ -79,10 +79,6 @@ impl PyMiddlewareService { locals, } } - - pub fn layer(handlers: PyMiddlewares, protocol: &str, locals: TaskLocals) -> PyMiddlewareLayer { - PyMiddlewareLayer::new(handlers, protocol, locals) - } } impl Service> for PyMiddlewareService @@ -158,3 +154,74 @@ where } } } + +#[cfg(test)] +mod tests { + use std::error::Error; + + use super::*; + + use aws_smithy_http_server::body::to_boxed; + use pyo3::prelude::*; + use tower::{Service, ServiceBuilder, ServiceExt}; + + use crate::middleware::PyMiddlewareHandler; + use crate::{PyMiddlewareException, PyRequest}; + + async fn echo(req: Request) -> Result, Box> { + Ok(Response::new(to_boxed(req.into_body()))) + } + + #[tokio::test] + async fn test_middlewares_are_chained_inside_layer() -> PyResult<()> { + let locals = crate::tests::initialize(); + let mut middlewares = PyMiddlewares(vec![]); + + Python::with_gil(|py| { + let middleware = PyModule::new(py, "middleware").unwrap(); + middleware.add_class::().unwrap(); + middleware.add_class::().unwrap(); + let pycode = r#" +def first_middleware(request: Request): + request.set_header("x-amzn-answer", "42") + return request + +def second_middleware(request: Request): + if request.get_header("x-amzn-answer") != "42": + raise MiddlewareException("wrong answer", 401) +"#; + py.run(pycode, Some(middleware.dict()), None)?; + let all = middleware.index()?; + let first_middleware = PyMiddlewareHandler { + func: middleware.getattr("first_middleware")?.into_py(py), + is_coroutine: false, + name: "first".to_string(), + }; + all.append("first_middleware")?; + middlewares.push(first_middleware); + let second_middleware = PyMiddlewareHandler { + func: middleware.getattr("second_middleware")?.into_py(py), + is_coroutine: false, + name: "second".to_string(), + }; + all.append("second_middleware")?; + middlewares.push(second_middleware); + Ok::<(), PyErr>(()) + })?; + + let mut service = ServiceBuilder::new() + .layer(PyMiddlewareLayer::new( + middlewares, + "aws.protocols#restJson1", + locals, + )) + .service_fn(echo); + + let request = Request::get("/").body(Body::empty()).unwrap(); + + let res = service.ready().await.unwrap().call(request).await.unwrap(); + + assert_eq!(res.status(), 200); + Ok(()) + } +} From 84cc06bdef914eaa2fa8133c63b712259f3dceda Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Fri, 16 Sep 2022 20:19:50 -0700 Subject: [PATCH 12/30] Add end to end test of the layer --- .../aws-smithy-http-server-python/src/middleware/handler.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs index e1a3645008..b6cc3da9f4 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs @@ -153,7 +153,7 @@ def first_middleware(request: Request): request.set_header("x-amzn-answer", "42") return request -async def second_middleware(request: Request): +def second_middleware(request: Request): if request.get_header("x-amzn-answer") != "42": raise MiddlewareException("wrong answer", 401) "#; From 38d971c673d80b6606e98df201bb4f4ec6113047 Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Fri, 16 Sep 2022 20:21:53 -0700 Subject: [PATCH 13/30] Remove useless dependency --- .../python/smithy/generators/PythonApplicationGenerator.kt | 2 -- 1 file changed, 2 deletions(-) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index 0248aab7c0..6d20a27876 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -7,7 +7,6 @@ package software.amazon.smithy.rust.codegen.server.python.smithy.generators import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.traits.DocumentationTrait -import software.amazon.smithy.rust.codegen.client.rustlang.CargoDependency import software.amazon.smithy.rust.codegen.client.rustlang.RustType import software.amazon.smithy.rust.codegen.client.rustlang.RustWriter import software.amazon.smithy.rust.codegen.client.rustlang.asType @@ -74,7 +73,6 @@ class PythonApplicationGenerator( arrayOf( "SmithyPython" to PythonServerCargoDependency.SmithyHttpServerPython(runtimeConfig).asType(), "SmithyServer" to ServerCargoDependency.SmithyHttpServer(runtimeConfig).asType(), - "http" to CargoDependency.Http.asType(), "pyo3" to PythonServerCargoDependency.PyO3.asType(), "pyo3_asyncio" to PythonServerCargoDependency.PyO3Asyncio.asType(), "tokio" to PythonServerCargoDependency.Tokio.asType(), From cf7521decfe9b6fbef2b602e39e8437c68463e23 Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Fri, 16 Sep 2022 20:24:29 -0700 Subject: [PATCH 14/30] Remove another useless dependency --- rust-runtime/aws-smithy-http-server-python/Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/rust-runtime/aws-smithy-http-server-python/Cargo.toml b/rust-runtime/aws-smithy-http-server-python/Cargo.toml index cde29acc9a..bf9cf99dca 100644 --- a/rust-runtime/aws-smithy-http-server-python/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server-python/Cargo.toml @@ -20,7 +20,6 @@ aws-smithy-types = { path = "../aws-smithy-types" } aws-smithy-xml = { path = "../aws-smithy-xml" } bytes = "1.2" futures = "0.3" -futures-core = "0.3" http = "0.2" hyper = { version = "0.14.20", features = ["server", "http1", "http2", "tcp", "stream"] } num_cpus = "1.13.1" From 80cdbfc7ffac7fbc6ac66310922ef2639d7c2935 Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Sat, 17 Sep 2022 15:07:04 -0700 Subject: [PATCH 15/30] Idiomatic logging refactoring --- .../smithy/PythonServerCargoDependency.kt | 1 + .../generators/PythonApplicationGenerator.kt | 20 +- .../generators/PythonServerModuleGenerator.kt | 19 ++ .../PythonServerOperationHandlerGenerator.kt | 9 + .../aws-smithy-http-server-python/Cargo.toml | 1 + .../examples/pokemon_service.py | 10 +- .../aws-smithy-http-server-python/src/lib.rs | 2 +- .../src/logging.rs | 270 ++++++++++-------- .../src/server.rs | 35 ++- 9 files changed, 214 insertions(+), 153 deletions(-) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCargoDependency.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCargoDependency.kt index d92482f324..2c37ceb14e 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCargoDependency.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCargoDependency.kt @@ -19,6 +19,7 @@ object PythonServerCargoDependency { val PyO3Asyncio: CargoDependency = CargoDependency("pyo3-asyncio", CratesIo("0.16"), features = setOf("attributes", "tokio-runtime")) val Tokio: CargoDependency = CargoDependency("tokio", CratesIo("1.20.1"), features = setOf("full")) val Tracing: CargoDependency = CargoDependency("tracing", CratesIo("0.1")) + val TracingAppender: CargoDependency = CargoDependency("tracing-appender", CratesIo("0.2")) val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4")) val TowerHttp: CargoDependency = CargoDependency("tower-http", CratesIo("0.3"), features = setOf("trace")) val Hyper: CargoDependency = CargoDependency("hyper", CratesIo("0.14.12"), features = setOf("server", "http1", "http2", "tcp", "stream")) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index 6d20a27876..5692d11275 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -77,6 +77,7 @@ class PythonApplicationGenerator( "pyo3_asyncio" to PythonServerCargoDependency.PyO3Asyncio.asType(), "tokio" to PythonServerCargoDependency.Tokio.asType(), "tracing" to PythonServerCargoDependency.Tracing.asType(), + "tracing_appender" to PythonServerCargoDependency.TracingAppender.asType(), "tower" to PythonServerCargoDependency.Tower.asType(), "tower_http" to PythonServerCargoDependency.TowerHttp.asType(), "num_cpus" to PythonServerCargoDependency.NumCpus.asType(), @@ -104,6 +105,8 @@ class PythonApplicationGenerator( middlewares: #{SmithyPython}::PyMiddlewares, context: Option<#{pyo3}::PyObject>, workers: #{parking_lot}::Mutex>, + _tracing_guard: Option<#{tracing_appender}::non_blocking::WorkerGuard>, + logfile: Option } """, *codegenScope, @@ -120,6 +123,8 @@ class PythonApplicationGenerator( middlewares: self.middlewares.clone(), context: self.context.clone(), workers: #{parking_lot}::Mutex::new(vec![]), + _tracing_guard: None, + logfile: self.logfile.clone() } } } @@ -223,10 +228,14 @@ class PythonApplicationGenerator( """ /// Create a new [App]. ##[new] - pub fn new(py: #{pyo3}::Python, log_level: Option<#{SmithyPython}::LogLevel>) -> #{pyo3}::PyResult { - let log_level = log_level.unwrap_or(#{SmithyPython}::LogLevel::Info); - #{SmithyPython}::logging::setup(py, log_level)?; - Ok(Self::default()) + pub fn new(py: #{pyo3}::Python, logfile: Option<&#{pyo3}::PyAny>) -> #{pyo3}::PyResult { + let logfile = if let Some(logfile) = logfile { + let logfile = logfile.extract::<&str>()?; + Some(std::path::Path::new(logfile).to_path_buf()) + } else { + None + }; + Ok(Self { logfile, ..Default::default() }) } /// Register a context object that will be shared between handlers. ##[pyo3(text_signature = "(${'$'}self, context)")] @@ -257,10 +266,11 @@ class PythonApplicationGenerator( pub fn start_worker( &mut self, py: pyo3::Python, - socket: &pyo3::PyCell, + socket: &pyo3::PyCell<#{SmithyPython}::PySocket>, worker_number: isize, ) -> pyo3::PyResult<()> { use #{SmithyPython}::PyApp; + self._tracing_guard = #{SmithyPython}::logging::setup_tracing(py, self.logfile.as_ref())?; let event_loop = self.configure_python_event_loop(py)?; let router = self.build_router(event_loop)?; self.start_hyper_worker(py, socket, event_loop, router, worker_number) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt index 76a0b90314..09fb6ec25a 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt @@ -47,6 +47,7 @@ class PythonServerModuleGenerator( renderPyCodegeneratedTypes() renderPyWrapperTypes() renderPySocketType() + renderPyLogging() renderPyMiddlewareTypes() renderPyApplicationType() } @@ -126,6 +127,24 @@ class PythonServerModuleGenerator( ) } + // Render Python shared socket type. + private fun RustWriter.renderPyLogging() { + rustTemplate( + """ + let logging = #{pyo3}::types::PyModule::new(py, "logging")?; + logging.add_function(#{pyo3}::wrap_pyfunction!(#{SmithyPython}::py_tracing_event, m)?)?; + logging.add_class::<#{SmithyPython}::PyTracingHandler>()?; + #{pyo3}::py_run!( + py, + logging, + "import sys; sys.modules['$libName.logging'] = logging" + ); + m.add_submodule(logging)?; + """, + *codegenScope, + ) + } + private fun RustWriter.renderPyMiddlewareTypes() { rustTemplate( """ diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt index a158d67722..709df4434b 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt @@ -69,11 +69,20 @@ class PythonServerOperationHandlerGenerator( handler: #{SmithyPython}::PyHandler, ) -> std::result::Result<$output, $error> { // Async block used to run the handler and catch any Python error. + let span = #{tracing}::span!( + #{tracing}::Level::TRACE, "python", + pid = #{tracing}::field::Empty, + module = #{tracing}::field::Empty, + filename = #{tracing}::field::Empty, + lineno = #{tracing}::field::Empty + ); + let guard = span.enter(); let result = if handler.is_coroutine { #{PyCoroutine:W} } else { #{PyFunction:W} }; + drop(guard); #{PyError:W} } """, diff --git a/rust-runtime/aws-smithy-http-server-python/Cargo.toml b/rust-runtime/aws-smithy-http-server-python/Cargo.toml index bf9cf99dca..b77292c324 100644 --- a/rust-runtime/aws-smithy-http-server-python/Cargo.toml +++ b/rust-runtime/aws-smithy-http-server-python/Cargo.toml @@ -35,6 +35,7 @@ tokio-stream = "0.1" tower = { version = "0.4.13", features = ["util"] } tracing = "0.1.36" tracing-subscriber = { version = "0.3.15", features = ["env-filter"] } +tracing-appender = { version = "0.2.2"} [dev-dependencies] pretty_assertions = "1" diff --git a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py index fa8a6c0a66..fdb43de069 100644 --- a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py +++ b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py @@ -11,12 +11,12 @@ from typing import List, Optional import aiohttp - from libpokemon_service_server_sdk import App from libpokemon_service_server_sdk.error import ResourceNotFoundException from libpokemon_service_server_sdk.input import ( EmptyOperationInput, GetPokemonSpeciesInput, GetServerStatisticsInput, HealthCheckOperationInput, StreamPokemonRadioOperationInput) +from libpokemon_service_server_sdk.logging import TracingHandler from libpokemon_service_server_sdk.middleware import (MiddlewareException, Request) from libpokemon_service_server_sdk.model import FlavorText, Language @@ -25,6 +25,10 @@ HealthCheckOperationOutput, StreamPokemonRadioOperationOutput) from libpokemon_service_server_sdk.types import ByteStream +# Logging can bee setup using standard Python tooling. We provide +# fast logging handler, Tracingandler based on Rust tracing crate. +logging.basicConfig(level=logging.INFO, handlers=[TracingHandler.handle()]) + # A slightly more atomic counter using a threading lock. class FastWriteCounter: @@ -185,7 +189,7 @@ def get_pokemon_species( context.increment_calls_count() flavor_text_entries = context.get_pokemon_description(input.name) if flavor_text_entries: - logging.debug("Total requests executed: %s", context.get_calls_count()) + logging.error("Total requests executed: %s", context.get_calls_count()) logging.info("Found description for Pokémon %s", input.name) return GetPokemonSpeciesOutput( name=input.name, flavor_text_entries=flavor_text_entries @@ -226,4 +230,4 @@ async def stream_pokemon_radio(_: StreamPokemonRadioOperationInput, context: Con ########################################################### # Run the server. ########################################################### -app.run(workers=1) +app.run(workers=3) diff --git a/rust-runtime/aws-smithy-http-server-python/src/lib.rs b/rust-runtime/aws-smithy-http-server-python/src/lib.rs index 687769774a..c5b4f838cb 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/lib.rs @@ -21,7 +21,7 @@ pub mod types; #[doc(inline)] pub use error::{PyError, PyMiddlewareException}; #[doc(inline)] -pub use logging::LogLevel; +pub use logging::{py_tracing_event, PyTracingHandler}; #[doc(inline)] pub use middleware::{PyHttpVersion, PyMiddlewareLayer, PyMiddlewares, PyRequest, PyResponse}; #[doc(inline)] diff --git a/rust-runtime/aws-smithy-http-server-python/src/logging.rs b/rust-runtime/aws-smithy-http-server-python/src/logging.rs index 4c690f3f7f..c8bbaabc4f 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/logging.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/logging.rs @@ -5,86 +5,95 @@ //! Rust `tracing` and Python `logging` setup and utilities. -use pyo3::prelude::*; -use tracing::Level; -use tracing_subscriber::filter::LevelFilter; -use tracing_subscriber::{prelude::*, EnvFilter}; +use std::path::PathBuf; -/// Setup `tracing::subscriber` reading the log level from RUST_LOG environment variable -/// and inject the custom Python `logger` into the interpreter. -pub fn setup(py: Python, level: LogLevel) -> PyResult<()> { - let format = tracing_subscriber::fmt::layer() - .with_ansi(true) - .with_line_number(true) - .with_level(true); - match EnvFilter::try_from_default_env() { - Ok(filter) => { - let level: LogLevel = filter.to_string().into(); - tracing_subscriber::registry() - .with(format) - .with(filter) - .init(); - setup_python_logging(py, level)?; - } - Err(_) => { - tracing_subscriber::registry() - .with(format) - .with(LevelFilter::from_level(level.into())) - .init(); - setup_python_logging(py, level)?; +use pyo3::prelude::*; +use tracing::{Level, Span}; +use tracing_appender::non_blocking::WorkerGuard; +use tracing_subscriber::{ + fmt::{self, writer::MakeWriterExt}, + layer::SubscriberExt, + util::SubscriberInitExt, +}; + +use crate::error::PyException; + +/// Setup `tracing::subscriber` reading the log level from RUST_LOG environment variable. +/// If the variable is not set, the logging for both Python and Rust will be set at the +/// level used by Python logging module. +pub fn setup_tracing(py: Python, logfile: Option<&PathBuf>) -> PyResult> { + let logging = py.import("logging")?; + let root = logging.getattr("root")?; + let handlers = root.getattr("handlers")?; + let handlers = handlers.extract::>()?; + for handler in handlers.iter() { + let name = handler.getattr(py, "__name__")?; + if let Ok(name) = name.extract::<&str>(py) { + if name == "SmithyRsTracingHandler" { + return setup_tracing_subscriber(py, logfile); + } } } - Ok(()) -} - -/// This custom logger enum exported to Python can be used to configure the -/// both the Rust `tracing` and Python `logging` levels. -/// We cannot export directly `tracing::Level` to Python. -#[pyclass] -#[derive(Debug, Clone, Copy)] -pub enum LogLevel { - Trace, - Debug, - Info, - Warn, - Error, + Ok(None) } -/// `From` is used to convert `LogLevel` to the correct string -/// needed by Python `logging` module. -impl From for String { - fn from(other: LogLevel) -> String { - match other { - LogLevel::Error => "ERROR".into(), - LogLevel::Warn => "WARN".into(), - LogLevel::Info => "INFO".into(), - _ => "DEBUG".into(), +fn setup_tracing_subscriber( + py: Python, + logfile: Option<&PathBuf>, +) -> PyResult> { + let appender = match logfile { + Some(logfile) => { + let parent = logfile.parent().ok_or_else(|| { + PyException::new_err(format!( + "Tracing setup failed: unable to extract dirname from path {}", + logfile.display() + )) + })?; + let filename = logfile.file_name().ok_or_else(|| { + PyException::new_err(format!( + "Tracing setup failed: unable to extract basename from path {}", + logfile.display() + )) + })?; + let file_appender = tracing_appender::rolling::hourly(parent, filename); + let (appender, guard) = tracing_appender::non_blocking(file_appender); + Some((appender, guard)) } - } -} + None => None, + }; -/// `From` is used to covert `tracing::EnvFilter` into `LogLevel`. -impl From for LogLevel { - fn from(other: String) -> LogLevel { - match other.as_str() { - "error" => LogLevel::Error, - "warn" => LogLevel::Warn, - "info" => LogLevel::Info, - "debug" => LogLevel::Debug, - _ => LogLevel::Trace, + let logging = py.import("logging")?; + let root = logging.getattr("root")?; + let level: u8 = root.getattr("level")?.extract()?; + let level = match level { + 40u8 => Level::ERROR, + 30u8 => Level::WARN, + 20u8 => Level::INFO, + 10u8 => Level::DEBUG, + _ => Level::TRACE, + }; + match appender { + Some((appender, guard)) => { + let layer = Some( + fmt::Layer::new() + .with_writer(appender.with_max_level(level)) + .with_ansi(true) + .with_line_number(true) + .with_level(true), + ); + tracing_subscriber::registry().with(layer).init(); + Ok(Some(guard)) } - } -} - -/// `From` is used to covert `LogLevel` into `tracing::EnvFilter`. -impl From for Level { - fn from(other: LogLevel) -> Level { - match other { - LogLevel::Debug => Level::DEBUG, - LogLevel::Info => Level::INFO, - LogLevel::Warn => Level::WARN, - LogLevel::Error => Level::ERROR, - _ => Level::TRACE, + None => { + let layer = Some( + fmt::Layer::new() + .with_writer(std::io::stdout.with_max_level(level)) + .with_ansi(true) + .with_line_number(true) + .with_level(true), + ); + tracing_subscriber::registry().with(layer).init(); + Ok(None) } } } @@ -99,70 +108,79 @@ impl From for Level { /// /// Since any call like `logging.warn(...)` sets up logging via `logging.basicConfig`, all log messages are now /// delivered to `crate::logging`, which will send them to `tracing::event!`. -fn setup_python_logging(py: Python, level: LogLevel) -> PyResult<()> { - let logging = py.import("logging")?; - logging.setattr("python_tracing", wrap_pyfunction!(python_tracing, logging)?)?; - - let level: String = level.into(); - let pycode = format!( - r#" -class RustTracing(Handler): +#[pyclass(name = "TracingHandler")] +#[derive(Debug, Clone)] +pub struct PyTracingHandler; + +#[pymethods] +impl PyTracingHandler { + #[staticmethod] + fn handle(py: Python) -> PyResult> { + let logging = py.import("logging")?; + logging.setattr( + "py_tracing_event", + wrap_pyfunction!(py_tracing_event, logging)?, + )?; + + let pycode = r#" +class TracingHandler(Handler): + __name__ = "SmithyRsTracingHandler" """ Python logging to Rust tracing handler. """ - def __init__(self, level=0): - super().__init__(level=level) - def emit(self, record): - python_tracing(record) - -# Store the old basicConfig in the local namespace. -oldBasicConfig = basicConfig - -def basicConfig(*pargs, **kwargs): - """ Reimplement basicConfig to hijack the root logger. """ - if "handlers" not in kwargs: - kwargs["handlers"] = [RustTracing()] - kwargs["level"] = {level} - return oldBasicConfig(*pargs, **kwargs) -"#, - ); - - py.run(&pycode, Some(logging.dict()), None)?; - let all = logging.index()?; - all.append("RustTracing")?; - Ok(()) + py_tracing_event( + record.levelno, record.getMessage(), record.module, + record.filename, record.lineno, record.process + ) +"#; + py.run(pycode, Some(logging.dict()), None)?; + let all = logging.index()?; + all.append("TracingHandler")?; + let handler = logging.getattr("TracingHandler")?; + Ok(handler.call0()?.into_py(py)) + } } /// Consumes a Python `logging.LogRecord` and emits a Rust [tracing::Event] instead. -#[cfg(not(test))] +// #[cfg(not(test))] #[pyfunction] -#[pyo3(text_signature = "(record)")] -fn python_tracing(record: &PyAny) -> PyResult<()> { - let level = record.getattr("levelno")?; - let message = record.getattr("getMessage")?.call0()?; - let module = record.getattr("module")?; - let filename = record.getattr("filename")?; - let line = record.getattr("lineno")?; - let pid = record.getattr("process")?; - - match level.extract()? { - 40u8 => tracing::event!(Level::ERROR, %pid, %module, %filename, %line, "{message}"), - 30u8 => tracing::event!(Level::WARN, %pid, %module, %filename, %line, "{message}"), - 20u8 => tracing::event!(Level::INFO, %pid, %module, %filename, %line, "{message}"), - 10u8 => tracing::event!(Level::DEBUG, %pid, %module, %filename, %line, "{message}"), +#[pyo3(text_signature = "(level, record, message, module, filename, line, pid)")] +pub fn py_tracing_event( + level: u8, + message: &str, + module: &str, + filename: &str, + line: usize, + pid: usize, +) -> PyResult<()> { + let span = Span::current(); + span.record("pid", pid); + span.record("module", module); + span.record("filename", filename); + span.record("lineno", line); + match level { + 40 => tracing::error!("{message}"), + 30 => tracing::warn!("{message}"), + 20 => tracing::info!("{message}"), + 10 => tracing::debug!("{message}"), _ => tracing::event!(Level::TRACE, %pid, %module, %filename, %line, "{message}"), }; - Ok(()) } -#[cfg(test)] -#[pyfunction] -#[pyo3(text_signature = "(record)")] -fn python_tracing(record: &PyAny) -> PyResult<()> { - let message = record.getattr("getMessage")?.call0()?; - pretty_assertions::assert_eq!(message.to_string(), "a message"); - Ok(()) -} +// #[cfg(test)] +// #[pyfunction] +// #[pyo3(text_signature = "(level, record, message, module, filename, line, pid)")] +// pub fn py_tracing_event( +// level: u8, +// message: &str, +// module: &str, +// filename: &str, +// line: usize, +// pid: usize, +// ) -> PyResult<()> { +// pretty_assertions::assert_eq!(message.to_string(), "a message"); +// Ok(()) +// } #[cfg(test)] mod tests { @@ -172,7 +190,7 @@ mod tests { fn tracing_handler_is_injected_in_python() { crate::tests::initialize(); Python::with_gil(|py| { - setup_python_logging(py, LogLevel::Info).unwrap(); + setup_tracing(py, None).unwrap(); let logging = py.import("logging").unwrap(); logging.call_method1("info", ("a message",)).unwrap(); }); diff --git a/rust-runtime/aws-smithy-http-server-python/src/server.rs b/rust-runtime/aws-smithy-http-server-python/src/server.rs index 1125ef0b37..ab80cae4a5 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/server.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/server.rs @@ -12,8 +12,9 @@ use pyo3::{prelude::*, types::IntoPyDict}; use signal_hook::{consts::*, iterator::Signals}; use tokio::runtime; use tower::ServiceBuilder; +use tracing_appender::non_blocking::WorkerGuard; -use crate::{middleware::PyMiddlewareHandler, PyMiddlewares, PySocket}; +use crate::{logging::setup_tracing, middleware::PyMiddlewareHandler, PyMiddlewares, PySocket}; /// A Python handler function representation. /// @@ -77,17 +78,15 @@ pub trait PyApp: Clone + pyo3::IntoPy { .getattr(py, "pid") .map(|pid| pid.extract(py).unwrap_or(-1)) .unwrap_or(-1); - tracing::debug!("Terminating worker {idx}, PID: {pid}"); + println!("Terminating worker {idx}, PID: {pid}"); match worker.call_method0(py, "terminate") { Ok(_) => {} Err(e) => { - tracing::error!("Error terminating worker {idx}, PID: {pid}: {e}"); + eprintln!("Error terminating worker {idx}, PID: {pid}: {e}"); worker .call_method0(py, "kill") .map_err(|e| { - tracing::error!( - "Unable to kill kill worker {idx}, PID: {pid}: {e}" - ); + eprintln!("Unable to kill kill worker {idx}, PID: {pid}: {e}"); }) .unwrap(); } @@ -108,11 +107,11 @@ pub trait PyApp: Clone + pyo3::IntoPy { .getattr(py, "pid") .map(|pid| pid.extract(py).unwrap_or(-1)) .unwrap_or(-1); - tracing::debug!("Killing worker {idx}, PID: {pid}"); + println!("Killing worker {idx}, PID: {pid}"); worker .call_method0(py, "kill") .map_err(|e| { - tracing::error!("Unable to kill kill worker {idx}, PID: {pid}: {e}"); + eprintln!("Unable to kill kill worker {idx}, PID: {pid}: {e}"); }) .unwrap(); }); @@ -135,20 +134,20 @@ pub trait PyApp: Clone + pyo3::IntoPy { for sig in signals.forever() { match sig { SIGINT => { - tracing::info!( + println!( "Termination signal {sig:?} received, all workers will be immediately terminated" ); self.immediate_termination(self.workers()); } SIGTERM | SIGQUIT => { - tracing::info!( + println!( "Termination signal {sig:?} received, all workers will be gracefully terminated" ); self.graceful_termination(self.workers()); } _ => { - tracing::warn!("Signal {sig:?} is ignored by this application"); + println!("Signal {sig:?} is ignored by this application"); } } } @@ -161,10 +160,12 @@ pub trait PyApp: Clone + pyo3::IntoPy { py.run( r#" import asyncio +import logging import functools import signal async def shutdown(sig, event_loop): + import asyncio import logging logging.info(f"Caught signal {sig.name}, cancelling tasks registered on this loop") tasks = [task for task in asyncio.all_tasks() if task is not @@ -278,10 +279,9 @@ event_loop.add_signal_handler(signal.SIGINT, func, is_coroutine, }; - tracing::info!( + println!( "Registering middleware function `{}`, coroutine: {}", - handler.name, - handler.is_coroutine, + handler.name, handler.is_coroutine, ); self.middlewares().push(handler); Ok(()) @@ -306,10 +306,9 @@ event_loop.add_signal_handler(signal.SIGINT, is_coroutine, args: func_args.len(), }; - tracing::info!( + println!( "Registering handler function `{name}`, coroutine: {}, arguments: {}", - handler.is_coroutine, - handler.args, + handler.is_coroutine, handler.args, ); // Insert the handler in the handlers map. self.handlers().insert(name.to_string(), handler); @@ -438,7 +437,7 @@ event_loop.add_signal_handler(signal.SIGINT, } // Unlock the workers mutex. drop(active_workers); - tracing::info!("Rust Python server started successfully"); + println!("Rust Python server started successfully"); self.block_on_rust_signals(); Ok(()) } From e79797fafc315574da2ecec406383ee37d6fda1d Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Sat, 17 Sep 2022 15:08:45 -0700 Subject: [PATCH 16/30] Enable back logging tests --- .../src/logging.rs | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/rust-runtime/aws-smithy-http-server-python/src/logging.rs b/rust-runtime/aws-smithy-http-server-python/src/logging.rs index c8bbaabc4f..92ec0f97c0 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/logging.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/logging.rs @@ -141,7 +141,7 @@ class TracingHandler(Handler): } /// Consumes a Python `logging.LogRecord` and emits a Rust [tracing::Event] instead. -// #[cfg(not(test))] +#[cfg(not(test))] #[pyfunction] #[pyo3(text_signature = "(level, record, message, module, filename, line, pid)")] pub fn py_tracing_event( @@ -167,20 +167,20 @@ pub fn py_tracing_event( Ok(()) } -// #[cfg(test)] -// #[pyfunction] -// #[pyo3(text_signature = "(level, record, message, module, filename, line, pid)")] -// pub fn py_tracing_event( -// level: u8, -// message: &str, -// module: &str, -// filename: &str, -// line: usize, -// pid: usize, -// ) -> PyResult<()> { -// pretty_assertions::assert_eq!(message.to_string(), "a message"); -// Ok(()) -// } +#[cfg(test)] +#[pyfunction] +#[pyo3(text_signature = "(level, record, message, module, filename, line, pid)")] +pub fn py_tracing_event( + level: u8, + message: &str, + module: &str, + filename: &str, + line: usize, + pid: usize, +) -> PyResult<()> { + pretty_assertions::assert_eq!(message.to_string(), "a message"); + Ok(()) +} #[cfg(test)] mod tests { From b08d3313e5582f81e07b31a9f1ea1b2d50b3fbaa Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Mon, 19 Sep 2022 15:36:48 -0700 Subject: [PATCH 17/30] Make clippy happy --- rust-runtime/aws-smithy-http-server-python/src/server.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/rust-runtime/aws-smithy-http-server-python/src/server.rs b/rust-runtime/aws-smithy-http-server-python/src/server.rs index ab80cae4a5..b348700506 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/server.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/server.rs @@ -12,9 +12,8 @@ use pyo3::{prelude::*, types::IntoPyDict}; use signal_hook::{consts::*, iterator::Signals}; use tokio::runtime; use tower::ServiceBuilder; -use tracing_appender::non_blocking::WorkerGuard; -use crate::{logging::setup_tracing, middleware::PyMiddlewareHandler, PyMiddlewares, PySocket}; +use crate::{middleware::PyMiddlewareHandler, PyMiddlewares, PySocket}; /// A Python handler function representation. /// @@ -165,6 +164,8 @@ import functools import signal async def shutdown(sig, event_loop): + # reimport asyncio and logging to be sure they are available when + # this handler runs on signal catching. import asyncio import logging logging.info(f"Caught signal {sig.name}, cancelling tasks registered on this loop") From 84155036fb9738f74d7c6f11de7b0f34bf71673d Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Mon, 19 Sep 2022 16:07:31 -0700 Subject: [PATCH 18/30] Another nudge for clippy happyness --- .../python/smithy/generators/PythonApplicationGenerator.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index 5692d11275..859e3e7703 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -228,7 +228,7 @@ class PythonApplicationGenerator( """ /// Create a new [App]. ##[new] - pub fn new(py: #{pyo3}::Python, logfile: Option<&#{pyo3}::PyAny>) -> #{pyo3}::PyResult { + pub fn new(logfile: Option<&#{pyo3}::PyAny>) -> #{pyo3}::PyResult { let logfile = if let Some(logfile) = logfile { let logfile = logfile.extract::<&str>()?; Some(std::path::Path::new(logfile).to_path_buf()) From 37036b8ea879f8f08ff8c2f2ce2eddbafaaf9344 Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Mon, 19 Sep 2022 16:20:53 -0700 Subject: [PATCH 19/30] Clippy again --- .../aws-smithy-http-server-python/src/logging.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/rust-runtime/aws-smithy-http-server-python/src/logging.rs b/rust-runtime/aws-smithy-http-server-python/src/logging.rs index 92ec0f97c0..5ad053aedd 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/logging.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/logging.rs @@ -8,7 +8,7 @@ use std::path::PathBuf; use pyo3::prelude::*; -use tracing::{Level, Span}; +use tracing::Level; use tracing_appender::non_blocking::WorkerGuard; use tracing_subscriber::{ fmt::{self, writer::MakeWriterExt}, @@ -171,12 +171,12 @@ pub fn py_tracing_event( #[pyfunction] #[pyo3(text_signature = "(level, record, message, module, filename, line, pid)")] pub fn py_tracing_event( - level: u8, + _level: u8, message: &str, - module: &str, - filename: &str, - line: usize, - pid: usize, + _module: &str, + _filename: &str, + _line: usize, + _pid: usize, ) -> PyResult<()> { pretty_assertions::assert_eq!(message.to_string(), "a message"); Ok(()) From 1d51bb8c3f2346d2bcd15feab3574beae59dd400 Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Tue, 20 Sep 2022 10:39:47 +0100 Subject: [PATCH 20/30] Span needs to be only available for not tests --- rust-runtime/aws-smithy-http-server-python/src/logging.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/rust-runtime/aws-smithy-http-server-python/src/logging.rs b/rust-runtime/aws-smithy-http-server-python/src/logging.rs index 5ad053aedd..0765860135 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/logging.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/logging.rs @@ -9,6 +9,8 @@ use std::path::PathBuf; use pyo3::prelude::*; use tracing::Level; +#[cfg(not(test))] +use tracing::Span; use tracing_appender::non_blocking::WorkerGuard; use tracing_subscriber::{ fmt::{self, writer::MakeWriterExt}, From 2a00ec2094296c9027e5b91de66cd0943c84716b Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Tue, 20 Sep 2022 11:49:55 +0100 Subject: [PATCH 21/30] Fix integration tests --- .../examples/pokemon_service.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py index fdb43de069..8243417667 100644 --- a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py +++ b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py @@ -143,28 +143,21 @@ def check_content_type_header(request: Request): if content_type == "application/json": logging.debug("Found valid `application/json` content type") else: - logging.warning(f"Invalid content type: {content_type}") + logging.warning(f"Invalid content type {content_type}, dumping headers: {request.headers()}") # This middleware adds a new header called `x-amazon-answer` to the # request. We expect to see this header to be populated in the next # middleware. @app.middleware -def add_x_amzn_stuff_header(request: Request): +def add_x_amzn_answer_header(request: Request): request.set_header("x-amzn-answer", "42") - logging.debug("Setting `x-amzn-stuff` header") + logging.debug("Setting `x-amzn-answer` header to 42") return request @app.middleware -async def check_method_and_content_length(request: Request): - content_length = request.get_header("content-length") - logging.debug(f"Request method: {request.method}") - if content_length is not None: - content_length = int(content_length) - logging.debug("Request content length: {content_length}") - else: - logging.warning(f"Invalid content length. Dumping headers: {request.headers()}") +async def check_x_amzn_answer_header(request: Request): # Check that `x-amzn-answer` is 42. if request.get_header("x-amzn-answer") != "42": # Return an HTTP 401 Unauthorized if the content type is not JSON. @@ -230,4 +223,4 @@ async def stream_pokemon_radio(_: StreamPokemonRadioOperationInput, context: Con ########################################################### # Run the server. ########################################################### -app.run(workers=3) +app.run(workers=1) From e75d2db9353a31ed4a62f10f44855a090bd5ffbf Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Tue, 20 Sep 2022 16:09:29 +0100 Subject: [PATCH 22/30] Add documentation and examples --- .../generators/PythonApplicationGenerator.kt | 27 +++---- .../examples/pokemon_service.py | 7 +- .../src/error.rs | 65 +++++++++-------- .../src/logging.rs | 10 +-- .../src/middleware/handler.rs | 38 ++++++++-- .../src/middleware/layer.rs | 20 ++++-- .../src/middleware/mod.rs | 4 +- .../src/middleware/request.rs | 71 ++++++++++++------- .../src/middleware/response.rs | 18 +++++ .../src/server.rs | 3 +- 10 files changed, 172 insertions(+), 91 deletions(-) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index 859e3e7703..941c2c620e 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -176,7 +176,7 @@ class PythonApplicationGenerator( self.middlewares.clone(), self.protocol(), middleware_locals - ), + )?, ); let router: #{SmithyServer}::Router = router .build() @@ -312,21 +312,17 @@ class PythonApplicationGenerator( """.trimIndent(), ) writer.rust( - if (operations.any { it.errors.isNotEmpty() }) { - """ - /// from $libName import ${Inputs.namespace} - /// from $libName import ${Outputs.namespace} - /// from $libName import ${Errors.namespace} - """.trimIndent() - } else { - """ - /// from $libName import ${Inputs.namespace} - /// from $libName import ${Outputs.namespace} - """.trimIndent() - }, + """ + /// from $libName import ${Inputs.namespace} + /// from $libName import ${Outputs.namespace} + """.trimIndent() ) + if (operations.any { it.errors.isNotEmpty() }) { + writer.rust("""/// from $libName import ${Errors.namespace}""".trimIndent()) + } writer.rust( """ + /// from $libName import middleware /// from $libName import App /// /// @dataclass @@ -336,6 +332,11 @@ class PythonApplicationGenerator( /// app = App() /// app.context(Context()) /// + /// @app.middleware + /// def middleware(request: middleware::Request): + /// if request.get_header("x-amzn-id") != "secret": + /// raise middleware.MiddlewareException("Unsupported `x-amz-id` header", 401) + /// """.trimIndent(), ) writer.operationImplementationStubs(operations) diff --git a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py index 8243417667..ac61390cee 100644 --- a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py +++ b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py @@ -116,7 +116,6 @@ def get_random_radio_stream(self) -> str: app = App() # Register the context. app.context(Context()) -# Register a middleware. ########################################################### @@ -129,7 +128,7 @@ def get_random_radio_stream(self) -> str: # changing the original request. # * Middleware returning a modified Request will update the original # request before continuing the execution. -# * Middleware returnign a Response will immediately terminate the request +# * Middleware returning a Response will immediately terminate the request # handling and return the response constructed from Python. # * Middleware raising MiddlewareException will immediately terminate the # request handling and return a protocol specific error, with the option of @@ -146,7 +145,7 @@ def check_content_type_header(request: Request): logging.warning(f"Invalid content type {content_type}, dumping headers: {request.headers()}") -# This middleware adds a new header called `x-amazon-answer` to the +# This middleware adds a new header called `x-amzn-answer` to the # request. We expect to see this header to be populated in the next # middleware. @app.middleware @@ -156,6 +155,8 @@ def add_x_amzn_answer_header(request: Request): return request +# This middleware checks if the header `x-amzn-answer` is correctly set +# to 42. @app.middleware async def check_x_amzn_answer_header(request: Request): # Check that `x-amzn-answer` is 42. diff --git a/rust-runtime/aws-smithy-http-server-python/src/error.rs b/rust-runtime/aws-smithy-http-server-python/src/error.rs index 9015449439..519e75501a 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/error.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/error.rs @@ -30,17 +30,23 @@ impl From for PyErr { } } +/// Exception that can be thrown from a Python middleware. +/// +/// It allows to specify a message and HTTP status code and implementing protocol specific capabilities +/// to build a [aws_smithy_http_server::response::Response] from it. #[pyclass(name = "MiddlewareException", extends = BasePyException)] +#[pyo3(text_signature = "(message, status_code)")] #[derive(Debug, Clone)] pub struct PyMiddlewareException { #[pyo3(get, set)] - pub message: String, + message: String, #[pyo3(get, set)] - pub status_code: u16, + status_code: u16, } #[pymethods] impl PyMiddlewareException { + /// Create a new [PyMiddlewareException]. #[new] fn newpy(message: String, status_code: Option) -> Self { Self { @@ -57,32 +63,8 @@ impl From for PyMiddlewareException { } impl PyMiddlewareException { - fn json_body(&self) -> String { - let mut out = String::new(); - let mut object = aws_smithy_json::serialize::JsonObjectWriter::new(&mut out); - object.key("message").string(self.message.as_str()); - object.finish(); - out - } - - fn xml_body(&self) -> String { - let mut out = String::new(); - { - let mut writer = aws_smithy_xml::encode::XmlWriter::new(&mut out); - let root = writer - .start_el("Error") - .write_ns("http://s3.amazonaws.com/doc/2006-03-01/", None); - let mut scope = root.finish(); - { - let mut inner_writer = scope.start_el("Message").finish(); - inner_writer.data(self.message.as_ref()); - } - scope.finish(); - } - out - } - - pub fn into_response(self, protocol: Protocol) -> Response { + /// Convert the exception into a [Response], following the [Protocol] specification. + pub(crate) fn into_response(self, protocol: Protocol) -> Response { let body = to_boxed(match protocol { Protocol::RestJson1 => self.json_body(), Protocol::RestXml => self.xml_body(), @@ -112,4 +94,31 @@ impl PyMiddlewareException { builder.body(body).expect("invalid HTTP response for `MiddlewareException`; please file a bug report under https://github.com/awslabs/smithy-rs/issues") } + + /// Serialize the body into a JSON object. + fn json_body(&self) -> String { + let mut out = String::new(); + let mut object = aws_smithy_json::serialize::JsonObjectWriter::new(&mut out); + object.key("message").string(self.message.as_str()); + object.finish(); + out + } + + /// Serialize the body into a XML object. + fn xml_body(&self) -> String { + let mut out = String::new(); + { + let mut writer = aws_smithy_xml::encode::XmlWriter::new(&mut out); + let root = writer + .start_el("Error") + .write_ns("http://s3.amazonaws.com/doc/2006-03-01/", None); + let mut scope = root.finish(); + { + let mut inner_writer = scope.start_el("Message").finish(); + inner_writer.data(self.message.as_ref()); + } + scope.finish(); + } + out + } } diff --git a/rust-runtime/aws-smithy-http-server-python/src/logging.rs b/rust-runtime/aws-smithy-http-server-python/src/logging.rs index 0765860135..34d8f1f42c 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/logging.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/logging.rs @@ -4,7 +4,6 @@ */ //! Rust `tracing` and Python `logging` setup and utilities. - use std::path::PathBuf; use pyo3::prelude::*; @@ -39,6 +38,7 @@ pub fn setup_tracing(py: Python, logfile: Option<&PathBuf>) -> PyResult, @@ -103,13 +103,9 @@ fn setup_tracing_subscriber( /// Modifies the Python `logging` module to deliver its log messages using [tracing::Subscriber] events. /// /// To achieve this goal, the following changes are made to the module: -/// - A new builtin function `logging.python_tracing` transcodes `logging.LogRecord`s to `tracing::Event`s. This function +/// - A new builtin function `logging.py_tracing_event` transcodes `logging.LogRecord`s to `tracing::Event`s. This function /// is not exported in `logging.__all__`, as it is not intended to be called directly. -/// - A new class `logging.RustTracing` provides a `logging.Handler` that delivers all records to `python_tracing`. -/// - `logging.basicConfig` is changed to use `logging.HostHandler` by default. -/// -/// Since any call like `logging.warn(...)` sets up logging via `logging.basicConfig`, all log messages are now -/// delivered to `crate::logging`, which will send them to `tracing::event!`. +/// - A new class `logging.TracingHandler` provides a `logging.Handler` that delivers all records to `python_tracing`. #[pyclass(name = "TracingHandler")] #[derive(Debug, Clone)] pub struct PyTracingHandler; diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs index b6cc3da9f4..f3a334c2c8 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs @@ -15,6 +15,10 @@ use crate::{PyMiddlewareException, PyRequest, PyResponse}; use super::PyFuture; +/// A Python middleware handler function representation. +/// +/// The Python business logic implementation needs to carry some information +/// to be executed properly like if it is a coroutine. #[derive(Debug, Clone)] pub struct PyMiddlewareHandler { pub name: String, @@ -22,20 +26,27 @@ pub struct PyMiddlewareHandler { pub is_coroutine: bool, } +/// Structure holding the list of Python middlewares that will be executed by this server. +/// +/// Middlewares are executed one after each other inside the [crate::PyMiddlewareLayer] Tower layer. #[derive(Debug, Clone, Default)] -pub struct PyMiddlewares(pub Vec); +pub struct PyMiddlewares(Vec); impl PyMiddlewares { + /// Create a new instance of `PyMiddlewareHandlers` from a list of heandlers. pub fn new(handlers: Vec) -> Self { Self(handlers) } + /// Add a new handler to the list. pub fn push(&mut self, handler: PyMiddlewareHandler) { self.0.push(handler); } - // Our request handler. This is where we would implement the application logic - // for responding to HTTP requests... + /// Execute a single middleware handler. + /// + /// The handler is scheduled on the Python interpreter syncronously or asynchronously, + /// dependening on the handler signature. async fn execute_middleware( request: PyRequest, handler: PyMiddlewareHandler, @@ -83,6 +94,19 @@ impl PyMiddlewares { }) } + /// Execute all the available Python middlewares in order of registration. + /// + /// Once the response is returned by the Python interpreter, different scenarios can happen: + /// * Middleware not returning will let the execution continue to the next middleware without + /// changing the original request. + /// * Middleware returning a modified [PyRequest] will update the original request before + /// continuing the execution of the next middleware. + /// * Middleware returning a [PyResponse] will immediately terminate the request handling and + /// return the response constructed from Python. + /// * Middleware raising [PyMiddlewareException] will immediately terminate the request handling + /// and return a protocol specific error, with the option of setting the HTTP return code. + /// * Middleware raising any other exception will immediately terminate the request handling and + /// return a protocol specific error, with HTTP status code 500. pub fn run( &mut self, mut request: Request, @@ -106,13 +130,13 @@ impl PyMiddlewares { Ok((pyrequest, pyresponse)) => { if let Some(pyrequest) = pyrequest { if let Ok(headers) = (&pyrequest.headers).try_into() { - tracing::debug!("Middleware `{name}` returned an HTTP request, override headers with middleware's one"); + tracing::debug!("Python middleware `{name}` returned an HTTP request, override headers with middleware's one"); *request.headers_mut() = headers; } } if let Some(pyresponse) = pyresponse { tracing::debug!( - "Middleware `{name}` returned a HTTP response, exit middleware loop" + "Python middleware `{name}` returned a HTTP response, exit middleware loop" ); return Err(pyresponse.into()); } @@ -125,7 +149,9 @@ impl PyMiddlewares { } } } - tracing::debug!("Returning original request to operation handler"); + tracing::debug!( + "Python middleware execution finised, returning the request to operation handler" + ); Ok(request) }) } diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs index 149167a18c..6ae82d0807 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +//! Tower layer implementation of Python middleware handling. use std::{ pin::Pin, task::{Context, Poll}, @@ -15,10 +16,11 @@ use aws_smithy_http_server::{ use futures::{ready, Future}; use http::{Request, Response}; use pin_project_lite::pin_project; +use pyo3::PyResult; use pyo3_asyncio::TaskLocals; use tower::{Layer, Service}; -use crate::{middleware::PyFuture, PyMiddlewares}; +use crate::{error::PyException, middleware::PyFuture, PyMiddlewares}; #[derive(Debug, Clone)] pub struct PyMiddlewareLayer { @@ -28,19 +30,27 @@ pub struct PyMiddlewareLayer { } impl PyMiddlewareLayer { - pub fn new(handlers: PyMiddlewares, protocol: &str, locals: TaskLocals) -> PyMiddlewareLayer { + pub fn new( + handlers: PyMiddlewares, + protocol: &str, + locals: TaskLocals, + ) -> PyResult { let protocol = match protocol { "aws.protocols#restJson1" => Protocol::RestJson1, "aws.protocols#restXml" => Protocol::RestXml, "aws.protocols#awsjson10" => Protocol::AwsJson10, "aws.protocols#awsjson11" => Protocol::AwsJson11, - _ => panic!(), + _ => { + return Err(PyException::new_err(format!( + "Protocol {protocol} is not supported" + ))) + } }; - Self { + Ok(Self { handlers, protocol, locals, - } + }) } } diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs index 92b9eb8b36..2ba36387f0 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +//! Schedule pure Python middlewares as `Tower` layers. mod handler; mod layer; mod request; @@ -17,4 +18,5 @@ pub use self::layer::PyMiddlewareLayer; pub use self::request::{PyHttpVersion, PyRequest}; pub use self::response::PyResponse; -pub type PyFuture = BoxFuture<'static, Result, Response>>; +/// Future type returned by the Python middleware handler. +pub(crate) type PyFuture = BoxFuture<'static, Result, Response>>; diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs index f3fcf83bcd..22dda22fd1 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs @@ -3,35 +3,58 @@ * SPDX-License-Identifier: Apache-2.0 */ +//! Python-compatible middleware [http::Request] implementation. use std::collections::HashMap; use aws_smithy_http_server::body::Body; use http::{Request, Version}; use pyo3::prelude::*; +/// Python compabible HTTP [Version]. #[pyclass(name = "HttpVersion")] #[derive(PartialEq, PartialOrd, Copy, Clone, Eq, Ord, Hash)] -pub enum PyHttpVersion { - Http09, - Http10, - Http11, - H2, - H3, - __NonExhaustive, +pub struct PyHttpVersion(Version); + +#[pymethods] +impl PyHttpVersion { + /// Extract the value of the HTTP [Version] into a string that + /// can be used by Python. + #[pyo3(text_signature = "($self)")] + fn value(&self) -> &str { + match self.0 { + Version::HTTP_09 => "HTTP/0.9", + Version::HTTP_10 => "HTTP/1.0", + Version::HTTP_11 => "HTTP/1.1", + Version::HTTP_2 => "HTTP/2.0", + Version::HTTP_3 => "HTTP/3.0", + _ => unreachable!(), + } + } } +/// Python-compatible [Request] object. +/// +/// For performance reasons, there is not support yet to pass the body to the Python middleware, +/// as it requires to consume and clone the body, which is a very expensive operation. +/// +/// TODO(if customers request for it, we can implemented an opt-in functionality to also pass +/// the body around). #[pyclass(name = "Request")] +#[pyo3(text_signature = "(request)")] #[derive(Debug, Clone)] pub struct PyRequest { #[pyo3(get, set)] method: String, #[pyo3(get, set)] uri: String, - pub headers: HashMap, + pub(crate) headers: HashMap, version: Version, } impl PyRequest { + /// Create a new Python-compatible [Request] structure from the Rust side. + /// + /// This is done by cloning the headers, method, URI and HTTP version to let them be owned by Python. pub fn new(request: &Request) -> Self { Self { method: request.method().to_string(), @@ -53,22 +76,14 @@ impl PyRequest { #[pymethods] impl PyRequest { #[new] + /// Create a new Python-compatible `Request` object from the Python side. fn newpy( method: String, uri: String, headers: Option>, version: Option, ) -> Self { - let version = version - .map(|v| match v { - PyHttpVersion::Http09 => Version::HTTP_09, - PyHttpVersion::Http10 => Version::HTTP_10, - PyHttpVersion::Http11 => Version::HTTP_11, - PyHttpVersion::H2 => Version::HTTP_2, - PyHttpVersion::H3 => Version::HTTP_3, - _ => unreachable!(), - }) - .unwrap_or(Version::HTTP_11); + let version = version.map(|v| v.0).unwrap_or(Version::HTTP_11); Self { method, uri, @@ -77,25 +92,27 @@ impl PyRequest { } } - fn version(&self) -> PyHttpVersion { - match self.version { - Version::HTTP_09 => PyHttpVersion::Http09, - Version::HTTP_10 => PyHttpVersion::Http10, - Version::HTTP_11 => PyHttpVersion::Http11, - Version::HTTP_2 => PyHttpVersion::H2, - Version::HTTP_3 => PyHttpVersion::H3, - _ => unreachable!(), - } + /// Return the HTTP version of this request. + #[pyo3(text_signature = "($self)")] + fn version(&self) -> String { + PyHttpVersion(self.version).value().to_string() } + /// Return the HTTP headers of this request. + /// TODO(can we use `Py::clone_ref()` to prevent cloning the hashmap?) + #[pyo3(text_signature = "($self)")] fn headers(&self) -> HashMap { self.headers.clone() } + /// Insert a new key/value into this request's headers. + #[pyo3(text_signature = "($self, key, value)")] fn set_header(&mut self, key: &str, value: &str) { self.headers.insert(key.to_string(), value.to_string()); } + /// Return a header value of this request. + #[pyo3(text_signature = "($self, key)")] fn get_header(&self, key: &str) -> Option<&String> { self.headers.get(key) } diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs index 786aa949e9..773fe76327 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/response.rs @@ -3,13 +3,22 @@ * SPDX-License-Identifier: Apache-2.0 */ +//! Python-compatible middleware [http::Request] implementation. use std::{collections::HashMap, convert::TryInto}; use aws_smithy_http_server::body::{to_boxed, BoxBody}; use http::{Response, StatusCode}; use pyo3::prelude::*; +/// Python-compatible [Response] object. +/// +/// For performance reasons, there is not support yet to pass the body to the Python middleware, +/// as it requires to consume and clone the body, which is a very expensive operation. +/// +// TODO(if customers request for it, we can implemented an opt-in functionality to also pass +// the body around). #[pyclass(name = "Response")] +#[pyo3(text_signature = "(status, headers, body)")] #[derive(Debug, Clone)] pub struct PyResponse { #[pyo3(get, set)] @@ -21,6 +30,7 @@ pub struct PyResponse { #[pymethods] impl PyResponse { + /// Python-compatible [Response] object from the Python side. #[new] fn newpy(status: u16, headers: Option>, body: Option>) -> Self { Self { @@ -30,19 +40,27 @@ impl PyResponse { } } + /// Return the HTTP headers of this response. + // TODO(can we use `Py::clone_ref()` to prevent cloning the hashmap?) + #[pyo3(text_signature = "($self)")] fn headers(&self) -> HashMap { self.headers.clone() } + /// Insert a new key/value into this response's headers. + #[pyo3(text_signature = "($self, key, value)")] fn set_header(&mut self, key: &str, value: &str) { self.headers.insert(key.to_string(), value.to_string()); } + /// Return a header value of this response. + #[pyo3(text_signature = "($self, key)")] fn get_header(&self, key: &str) -> Option<&String> { self.headers.get(key) } } +/// Allow to convert between a [PyResponse] and a [Response]. impl From for Response { fn from(pyresponse: PyResponse) -> Self { let mut response = Response::builder() diff --git a/rust-runtime/aws-smithy-http-server-python/src/server.rs b/rust-runtime/aws-smithy-http-server-python/src/server.rs index b348700506..09c7e14588 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/server.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/server.rs @@ -258,9 +258,10 @@ event_loop.add_signal_handler(signal.SIGINT, Ok(()) } + // Check if a Python function is a coroutine. Since the function has not run yet, + // we cannot use `asyncio.iscoroutine()`, we need to use `inspect.iscoroutinefunction()`. fn is_coroutine(&self, py: Python, func: &PyObject) -> PyResult { let inspect = py.import("inspect")?; - // Check if the function is a coroutine. // NOTE: that `asyncio.iscoroutine()` doesn't work here. inspect .call_method1("iscoroutinefunction", (func,))? From 1cddc954e2e7c70eda18936b762dd5e1766aa342 Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Tue, 20 Sep 2022 16:23:26 +0100 Subject: [PATCH 23/30] Fix test --- .../aws-smithy-http-server-python/src/middleware/layer.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs index 6ae82d0807..d941082b5c 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs @@ -185,7 +185,7 @@ mod tests { #[tokio::test] async fn test_middlewares_are_chained_inside_layer() -> PyResult<()> { let locals = crate::tests::initialize(); - let mut middlewares = PyMiddlewares(vec![]); + let mut middlewares = PyMiddlewares::new(vec![]); Python::with_gil(|py| { let middleware = PyModule::new(py, "middleware").unwrap(); @@ -224,7 +224,7 @@ def second_middleware(request: Request): middlewares, "aws.protocols#restJson1", locals, - )) + )?) .service_fn(echo); let request = Request::get("/").body(Body::empty()).unwrap(); From 324e30f798034d496a2f220bc82a5dc30a8250ec Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Tue, 20 Sep 2022 16:30:50 +0100 Subject: [PATCH 24/30] Fix kotlin linting --- .../python/smithy/generators/PythonApplicationGenerator.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index 941c2c620e..29310ccff6 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -315,7 +315,7 @@ class PythonApplicationGenerator( """ /// from $libName import ${Inputs.namespace} /// from $libName import ${Outputs.namespace} - """.trimIndent() + """.trimIndent(), ) if (operations.any { it.errors.isNotEmpty() }) { writer.rust("""/// from $libName import ${Errors.namespace}""".trimIndent()) From 34c55f103277ded72c4466524def7377cebbf98a Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Tue, 20 Sep 2022 18:10:49 +0100 Subject: [PATCH 25/30] Reword middleware to explicitly tell we only support requests so far --- .../generators/PythonApplicationGenerator.kt | 10 +++++----- .../examples/pokemon_service.py | 6 +++--- .../aws-smithy-http-server-python/src/lib.rs | 4 +++- .../src/middleware/handler.rs | 20 +++++++++++++++---- .../src/middleware/layer.rs | 15 ++++++++++++-- .../src/middleware/mod.rs | 3 ++- .../src/server.rs | 14 +++++++++---- 7 files changed, 52 insertions(+), 20 deletions(-) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index 29310ccff6..c6190b67de 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -242,11 +242,11 @@ class PythonApplicationGenerator( pub fn context(&mut self, context: #{pyo3}::PyObject) { self.context = Some(context); } - /// Register a middleware function that will be run inside a Tower layer, without cloning the body. + /// Register a request middleware function that will be run inside a Tower layer, without cloning the body. ##[pyo3(text_signature = "(${'$'}self, func)")] - pub fn middleware(&mut self, py: pyo3::Python, func: pyo3::PyObject) -> pyo3::PyResult<()> { + pub fn request_middleware(&mut self, py: #{pyo3}::Python, func: #{pyo3}::PyObject) -> #{pyo3}::PyResult<()> { use #{SmithyPython}::PyApp; - self.register_middleware(py, func) + self.register_middleware(py, func, #{SmithyPython}::PyMiddlewareType::Request) } /// Main entrypoint: start the server on multiple workers. ##[pyo3(text_signature = "(${'$'}self, address, port, backlog, workers)")] @@ -332,8 +332,8 @@ class PythonApplicationGenerator( /// app = App() /// app.context(Context()) /// - /// @app.middleware - /// def middleware(request: middleware::Request): + /// @app.request_middleware + /// def request_middleware(request: middleware::Request): /// if request.get_header("x-amzn-id") != "secret": /// raise middleware.MiddlewareException("Unsupported `x-amz-id` header", 401) /// diff --git a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py index ac61390cee..dba8605e1f 100644 --- a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py +++ b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py @@ -136,7 +136,7 @@ def get_random_radio_stream(self) -> str: # * Middleware raising any other exception will immediately terminate the # request handling and return a protocol specific error, with HTTP status # code 500. -@app.middleware +@app.request_middleware def check_content_type_header(request: Request): content_type = request.get_header("content-type") if content_type == "application/json": @@ -148,7 +148,7 @@ def check_content_type_header(request: Request): # This middleware adds a new header called `x-amzn-answer` to the # request. We expect to see this header to be populated in the next # middleware. -@app.middleware +@app.request_middleware def add_x_amzn_answer_header(request: Request): request.set_header("x-amzn-answer", "42") logging.debug("Setting `x-amzn-answer` header to 42") @@ -157,7 +157,7 @@ def add_x_amzn_answer_header(request: Request): # This middleware checks if the header `x-amzn-answer` is correctly set # to 42. -@app.middleware +@app.request_middleware async def check_x_amzn_answer_header(request: Request): # Check that `x-amzn-answer` is 42. if request.get_header("x-amzn-answer") != "42": diff --git a/rust-runtime/aws-smithy-http-server-python/src/lib.rs b/rust-runtime/aws-smithy-http-server-python/src/lib.rs index c5b4f838cb..793104af59 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/lib.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/lib.rs @@ -23,7 +23,9 @@ pub use error::{PyError, PyMiddlewareException}; #[doc(inline)] pub use logging::{py_tracing_event, PyTracingHandler}; #[doc(inline)] -pub use middleware::{PyHttpVersion, PyMiddlewareLayer, PyMiddlewares, PyRequest, PyResponse}; +pub use middleware::{ + PyHttpVersion, PyMiddlewareLayer, PyMiddlewareType, PyMiddlewares, PyRequest, PyResponse, +}; #[doc(inline)] pub use server::{PyApp, PyHandler}; #[doc(inline)] diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs index f3a334c2c8..00f99312e8 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/handler.rs @@ -15,6 +15,12 @@ use crate::{PyMiddlewareException, PyRequest, PyResponse}; use super::PyFuture; +#[derive(Debug, Clone, Copy)] +pub enum PyMiddlewareType { + Request, + Response, +} + /// A Python middleware handler function representation. /// /// The Python business logic implementation needs to carry some information @@ -24,6 +30,7 @@ pub struct PyMiddlewareHandler { pub name: String, pub func: PyObject, pub is_coroutine: bool, + pub _type: PyMiddlewareType, } /// Structure holding the list of Python middlewares that will be executed by this server. @@ -166,7 +173,7 @@ mod tests { use super::*; #[tokio::test] - async fn middleware_chain_keeps_headers_changes() -> PyResult<()> { + async fn request_middleware_chain_keeps_headers_changes() -> PyResult<()> { let locals = crate::tests::initialize(); let mut middlewares = PyMiddlewares(vec![]); @@ -189,6 +196,7 @@ def second_middleware(request: Request): func: middleware.getattr("first_middleware")?.into_py(py), is_coroutine: false, name: "first".to_string(), + _type: PyMiddlewareType::Request, }; all.append("first_middleware")?; middlewares.push(first_middleware); @@ -196,6 +204,7 @@ def second_middleware(request: Request): func: middleware.getattr("second_middleware")?.into_py(py), is_coroutine: false, name: "second".to_string(), + _type: PyMiddlewareType::Request, }; all.append("second_middleware")?; middlewares.push(second_middleware); @@ -218,7 +227,7 @@ def second_middleware(request: Request): } #[tokio::test] - async fn middleware_return_response() -> PyResult<()> { + async fn request_middleware_return_response() -> PyResult<()> { let locals = crate::tests::initialize(); let mut middlewares = PyMiddlewares(vec![]); @@ -235,6 +244,7 @@ def middleware(request: Request): func: middleware.getattr("middleware")?.into_py(py), is_coroutine: false, name: "middleware".to_string(), + _type: PyMiddlewareType::Request, }; all.append("middleware")?; middlewares.push(middleware); @@ -256,7 +266,7 @@ def middleware(request: Request): } #[tokio::test] - async fn middleware_raise_middleware_exception() -> PyResult<()> { + async fn request_middleware_raise_middleware_exception() -> PyResult<()> { let locals = crate::tests::initialize(); let mut middlewares = PyMiddlewares(vec![]); @@ -273,6 +283,7 @@ def middleware(request: Request): func: middleware.getattr("middleware")?.into_py(py), is_coroutine: false, name: "middleware".to_string(), + _type: PyMiddlewareType::Request, }; all.append("middleware")?; middlewares.push(middleware); @@ -298,7 +309,7 @@ def middleware(request: Request): } #[tokio::test] - async fn middleware_raise_python_exception() -> PyResult<()> { + async fn request_middleware_raise_python_exception() -> PyResult<()> { let locals = crate::tests::initialize(); let mut middlewares = PyMiddlewares(vec![]); @@ -315,6 +326,7 @@ def middleware(request): func: middleware.getattr("middleware")?.into_py(py), is_coroutine: false, name: "middleware".to_string(), + _type: PyMiddlewareType::Request, }; middlewares.push(middleware); Ok::<(), PyErr>(()) diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs index d941082b5c..b371fbad8b 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs @@ -22,6 +22,10 @@ use tower::{Layer, Service}; use crate::{error::PyException, middleware::PyFuture, PyMiddlewares}; +/// Tower [Layer] implementation of Python middleware handling. +/// +/// Middleware stored in the `handlers` attribute will be executed, in order, +/// inside an async Tower middleware. #[derive(Debug, Clone)] pub struct PyMiddlewareLayer { handlers: PyMiddlewares, @@ -67,6 +71,7 @@ impl Layer for PyMiddlewareLayer { } } +// Tower [Service] wrapping the Python middleware [Layer]. #[derive(Clone, Debug)] pub struct PyMiddlewareService { inner: S, @@ -115,6 +120,7 @@ where } pin_project! { + /// Response future handling the state transition between a running and a done future. pub struct ResponseFuture where S: Service>, @@ -126,6 +132,7 @@ pin_project! { } pin_project! { + /// Representation of the result of the middleware execution. #[project = StateProj] enum State { Running { @@ -149,6 +156,7 @@ where let mut this = self.project(); loop { match this.middleware.as_mut().project() { + // Run the handler and store the future inside the inner state. StateProj::Running { run } => { let run = ready!(run.poll(cx)); match run { @@ -159,6 +167,7 @@ where Err(res) => return Poll::Ready(Ok(res)), } } + // Execute the future returned by the layer. StateProj::Done { fut } => return fut.poll(cx), } } @@ -176,14 +185,14 @@ mod tests { use tower::{Service, ServiceBuilder, ServiceExt}; use crate::middleware::PyMiddlewareHandler; - use crate::{PyMiddlewareException, PyRequest}; + use crate::{PyMiddlewareException, PyMiddlewareType, PyRequest}; async fn echo(req: Request) -> Result, Box> { Ok(Response::new(to_boxed(req.into_body()))) } #[tokio::test] - async fn test_middlewares_are_chained_inside_layer() -> PyResult<()> { + async fn request_middlewares_are_chained_inside_layer() -> PyResult<()> { let locals = crate::tests::initialize(); let mut middlewares = PyMiddlewares::new(vec![]); @@ -206,6 +215,7 @@ def second_middleware(request: Request): func: middleware.getattr("first_middleware")?.into_py(py), is_coroutine: false, name: "first".to_string(), + _type: PyMiddlewareType::Request, }; all.append("first_middleware")?; middlewares.push(first_middleware); @@ -213,6 +223,7 @@ def second_middleware(request: Request): func: middleware.getattr("second_middleware")?.into_py(py), is_coroutine: false, name: "second".to_string(), + _type: PyMiddlewareType::Request, }; all.append("second_middleware")?; middlewares.push(second_middleware); diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs index 2ba36387f0..a1a2d14ced 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/mod.rs @@ -13,10 +13,11 @@ use aws_smithy_http_server::body::{Body, BoxBody}; use futures::future::BoxFuture; use http::{Request, Response}; -pub use self::handler::{PyMiddlewareHandler, PyMiddlewares}; +pub use self::handler::{PyMiddlewareType, PyMiddlewares}; pub use self::layer::PyMiddlewareLayer; pub use self::request::{PyHttpVersion, PyRequest}; pub use self::response::PyResponse; +pub(crate) use self::handler::PyMiddlewareHandler; /// Future type returned by the Python middleware handler. pub(crate) type PyFuture = BoxFuture<'static, Result, Response>>; diff --git a/rust-runtime/aws-smithy-http-server-python/src/server.rs b/rust-runtime/aws-smithy-http-server-python/src/server.rs index 09c7e14588..bf86bf4c48 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/server.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/server.rs @@ -13,7 +13,7 @@ use signal_hook::{consts::*, iterator::Signals}; use tokio::runtime; use tower::ServiceBuilder; -use crate::{middleware::PyMiddlewareHandler, PyMiddlewares, PySocket}; +use crate::{middleware::PyMiddlewareHandler, PyMiddlewareType, PyMiddlewares, PySocket}; /// A Python handler function representation. /// @@ -272,7 +272,12 @@ event_loop.add_signal_handler(signal.SIGINT, /// /// There are some information needed to execute the Python code from a Rust handler, /// such has if the registered function needs to be awaited (if it is a coroutine).. - fn register_middleware(&mut self, py: Python, func: PyObject) -> PyResult<()> { + fn register_middleware( + &mut self, + py: Python, + func: PyObject, + _type: PyMiddlewareType, + ) -> PyResult<()> { let name = func.getattr(py, "__name__")?.extract::(py)?; let is_coroutine = self.is_coroutine(py, &func)?; // Find number of expected methods (a Python implementation could not accept the context). @@ -280,10 +285,11 @@ event_loop.add_signal_handler(signal.SIGINT, name, func, is_coroutine, + _type, }; println!( - "Registering middleware function `{}`, coroutine: {}", - handler.name, handler.is_coroutine, + "Registering middleware function `{}`, coroutine: {}, type: {:?}", + handler.name, handler.is_coroutine, handler._type ); self.middlewares().push(handler); Ok(()) From fb58bdfbaf82e607113703fea82da8ee14647e76 Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Tue, 20 Sep 2022 18:24:36 +0100 Subject: [PATCH 26/30] Update changelog --- CHANGELOG.next.toml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index fc4c4c2578..496c781d01 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -9,4 +9,9 @@ # message = "Fix typos in module documentation for generated crates" # references = ["smithy-rs#920"] # meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "client | server | all"} -# author = "rcoh" \ No newline at end of file +# author = "rcoh" +[[smithy-rs]] +message = "Implement support for pure Python request middleware. Improve idiomatic logging support over tracing." +references = ["smithy-rs#1734"] +meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "server"} +author = "crisidev" From 8f98141e1dd3b520abc10d8af5562a13e9681be7 Mon Sep 17 00:00:00 2001 From: Matteo Bigoi <1781140+crisidev@users.noreply.github.com> Date: Wed, 21 Sep 2022 15:42:39 +0100 Subject: [PATCH 27/30] Apply suggestions from code review Co-authored-by: Burak --- .../python/smithy/generators/PythonServerModuleGenerator.kt | 2 +- rust-runtime/aws-smithy-http-server-python/src/logging.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt index 09fb6ec25a..654e7c10e0 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerModuleGenerator.kt @@ -156,7 +156,7 @@ class PythonServerModuleGenerator( pyo3::py_run!( py, middleware, - "import sys; sys.modules['libpokemon_service_server_sdk.middleware'] = middleware" + "import sys; sys.modules['$libName.middleware'] = middleware" ); m.add_submodule(middleware)?; """, diff --git a/rust-runtime/aws-smithy-http-server-python/src/logging.rs b/rust-runtime/aws-smithy-http-server-python/src/logging.rs index 34d8f1f42c..47b58324bc 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/logging.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/logging.rs @@ -160,7 +160,7 @@ pub fn py_tracing_event( 30 => tracing::warn!("{message}"), 20 => tracing::info!("{message}"), 10 => tracing::debug!("{message}"), - _ => tracing::event!(Level::TRACE, %pid, %module, %filename, %line, "{message}"), + _ => tracing::trace!("{message}"), }; Ok(()) } From 4ee2743d809d224b41e958f2fb16eb59b852e446 Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Wed, 21 Sep 2022 18:56:16 +0100 Subject: [PATCH 28/30] Refactor logging for a more idiomatic experience --- CHANGELOG.next.toml | 8 +- .../generators/PythonApplicationGenerator.kt | 16 +--- .../PythonServerOperationHandlerGenerator.kt | 9 -- .../examples/pokemon_service.py | 10 +- .../src/logging.rs | 93 ++++++++++--------- .../src/middleware/layer.rs | 5 +- .../src/middleware/request.rs | 4 + .../src/server.rs | 31 ++++--- 8 files changed, 90 insertions(+), 86 deletions(-) diff --git a/CHANGELOG.next.toml b/CHANGELOG.next.toml index ff2f4f89bc..49cdd29e26 100644 --- a/CHANGELOG.next.toml +++ b/CHANGELOG.next.toml @@ -13,5 +13,11 @@ [[rust-runtime]] message = "Pokémon Service example code now runs clippy during build." references = ["smithy-rs#1727"] -meta = { "breaking" = false, "tada" = false, "bug" = false } +meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "server" } author = "GeneralSwiss" + +[[smithy-rs]] +message = "Implement support for pure Python request middleware. Improve idiomatic logging support over tracing." +references = ["smithy-rs#1734"] +meta = { "breaking" = false, "tada" = false, "bug" = false, "target" = "server" } +author = "crisidev" diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt index c6190b67de..a1d7d818c2 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonApplicationGenerator.kt @@ -77,7 +77,6 @@ class PythonApplicationGenerator( "pyo3_asyncio" to PythonServerCargoDependency.PyO3Asyncio.asType(), "tokio" to PythonServerCargoDependency.Tokio.asType(), "tracing" to PythonServerCargoDependency.Tracing.asType(), - "tracing_appender" to PythonServerCargoDependency.TracingAppender.asType(), "tower" to PythonServerCargoDependency.Tower.asType(), "tower_http" to PythonServerCargoDependency.TowerHttp.asType(), "num_cpus" to PythonServerCargoDependency.NumCpus.asType(), @@ -105,8 +104,6 @@ class PythonApplicationGenerator( middlewares: #{SmithyPython}::PyMiddlewares, context: Option<#{pyo3}::PyObject>, workers: #{parking_lot}::Mutex>, - _tracing_guard: Option<#{tracing_appender}::non_blocking::WorkerGuard>, - logfile: Option } """, *codegenScope, @@ -123,8 +120,6 @@ class PythonApplicationGenerator( middlewares: self.middlewares.clone(), context: self.context.clone(), workers: #{parking_lot}::Mutex::new(vec![]), - _tracing_guard: None, - logfile: self.logfile.clone() } } } @@ -228,14 +223,8 @@ class PythonApplicationGenerator( """ /// Create a new [App]. ##[new] - pub fn new(logfile: Option<&#{pyo3}::PyAny>) -> #{pyo3}::PyResult { - let logfile = if let Some(logfile) = logfile { - let logfile = logfile.extract::<&str>()?; - Some(std::path::Path::new(logfile).to_path_buf()) - } else { - None - }; - Ok(Self { logfile, ..Default::default() }) + pub fn new() -> Self { + Self::default() } /// Register a context object that will be shared between handlers. ##[pyo3(text_signature = "(${'$'}self, context)")] @@ -270,7 +259,6 @@ class PythonApplicationGenerator( worker_number: isize, ) -> pyo3::PyResult<()> { use #{SmithyPython}::PyApp; - self._tracing_guard = #{SmithyPython}::logging::setup_tracing(py, self.logfile.as_ref())?; let event_loop = self.configure_python_event_loop(py)?; let router = self.build_router(event_loop)?; self.start_hyper_worker(py, socket, event_loop, router, worker_number) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt index 709df4434b..a158d67722 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/generators/PythonServerOperationHandlerGenerator.kt @@ -69,20 +69,11 @@ class PythonServerOperationHandlerGenerator( handler: #{SmithyPython}::PyHandler, ) -> std::result::Result<$output, $error> { // Async block used to run the handler and catch any Python error. - let span = #{tracing}::span!( - #{tracing}::Level::TRACE, "python", - pid = #{tracing}::field::Empty, - module = #{tracing}::field::Empty, - filename = #{tracing}::field::Empty, - lineno = #{tracing}::field::Empty - ); - let guard = span.enter(); let result = if handler.is_coroutine { #{PyCoroutine:W} } else { #{PyFunction:W} }; - drop(guard); #{PyError:W} } """, diff --git a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py index dba8605e1f..12d81f875c 100644 --- a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py +++ b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py @@ -11,6 +11,7 @@ from typing import List, Optional import aiohttp + from libpokemon_service_server_sdk import App from libpokemon_service_server_sdk.error import ResourceNotFoundException from libpokemon_service_server_sdk.input import ( @@ -27,7 +28,7 @@ # Logging can bee setup using standard Python tooling. We provide # fast logging handler, Tracingandler based on Rust tracing crate. -logging.basicConfig(level=logging.INFO, handlers=[TracingHandler.handle()]) +logging.basicConfig(handlers=[TracingHandler(level=logging.DEBUG).handler()]) # A slightly more atomic counter using a threading lock. @@ -142,7 +143,9 @@ def check_content_type_header(request: Request): if content_type == "application/json": logging.debug("Found valid `application/json` content type") else: - logging.warning(f"Invalid content type {content_type}, dumping headers: {request.headers()}") + logging.warning( + f"Invalid content type {content_type}, dumping headers: {request.headers()}" + ) # This middleware adds a new header called `x-amzn-answer` to the @@ -183,8 +186,9 @@ def get_pokemon_species( context.increment_calls_count() flavor_text_entries = context.get_pokemon_description(input.name) if flavor_text_entries: - logging.error("Total requests executed: %s", context.get_calls_count()) + logging.debug("Total requests executed: %s", context.get_calls_count()) logging.info("Found description for Pokémon %s", input.name) + logging.error("Found some stuff") return GetPokemonSpeciesOutput( name=input.name, flavor_text_entries=flavor_text_entries ) diff --git a/rust-runtime/aws-smithy-http-server-python/src/logging.rs b/rust-runtime/aws-smithy-http-server-python/src/logging.rs index 47b58324bc..c7894d0c5e 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/logging.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/logging.rs @@ -7,9 +7,9 @@ use std::path::PathBuf; use pyo3::prelude::*; -use tracing::Level; #[cfg(not(test))] -use tracing::Span; +use tracing::span; +use tracing::Level; use tracing_appender::non_blocking::WorkerGuard; use tracing_subscriber::{ fmt::{self, writer::MakeWriterExt}, @@ -19,29 +19,10 @@ use tracing_subscriber::{ use crate::error::PyException; -/// Setup `tracing::subscriber` reading the log level from RUST_LOG environment variable. -/// If the variable is not set, the logging for both Python and Rust will be set at the -/// level used by Python logging module. -pub fn setup_tracing(py: Python, logfile: Option<&PathBuf>) -> PyResult> { - let logging = py.import("logging")?; - let root = logging.getattr("root")?; - let handlers = root.getattr("handlers")?; - let handlers = handlers.extract::>()?; - for handler in handlers.iter() { - let name = handler.getattr(py, "__name__")?; - if let Ok(name) = name.extract::<&str>(py) { - if name == "SmithyRsTracingHandler" { - return setup_tracing_subscriber(py, logfile); - } - } - } - Ok(None) -} - /// Setup tracing-subscriber to log on console or to a hourly rolling file. fn setup_tracing_subscriber( - py: Python, - logfile: Option<&PathBuf>, + level: Option, + logfile: Option, ) -> PyResult> { let appender = match logfile { Some(logfile) => { @@ -64,21 +45,20 @@ fn setup_tracing_subscriber( None => None, }; - let logging = py.import("logging")?; - let root = logging.getattr("root")?; - let level: u8 = root.getattr("level")?.extract()?; - let level = match level { - 40u8 => Level::ERROR, - 30u8 => Level::WARN, - 20u8 => Level::INFO, - 10u8 => Level::DEBUG, + let tracing_level = match level { + Some(40u8) => Level::ERROR, + Some(30u8) => Level::WARN, + Some(20u8) => Level::INFO, + Some(10u8) => Level::DEBUG, + None => Level::INFO, _ => Level::TRACE, }; + match appender { Some((appender, guard)) => { let layer = Some( fmt::Layer::new() - .with_writer(appender.with_max_level(level)) + .with_writer(appender.with_max_level(tracing_level)) .with_ansi(true) .with_line_number(true) .with_level(true), @@ -89,7 +69,7 @@ fn setup_tracing_subscriber( None => { let layer = Some( fmt::Layer::new() - .with_writer(std::io::stdout.with_max_level(level)) + .with_writer(std::io::stdout.with_max_level(tracing_level)) .with_ansi(true) .with_line_number(true) .with_level(true), @@ -107,13 +87,24 @@ fn setup_tracing_subscriber( /// is not exported in `logging.__all__`, as it is not intended to be called directly. /// - A new class `logging.TracingHandler` provides a `logging.Handler` that delivers all records to `python_tracing`. #[pyclass(name = "TracingHandler")] -#[derive(Debug, Clone)] -pub struct PyTracingHandler; +#[derive(Debug)] +pub struct PyTracingHandler { + _guard: Option, +} #[pymethods] impl PyTracingHandler { - #[staticmethod] - fn handle(py: Python) -> PyResult> { + #[new] + fn newpy(py: Python, level: Option, logfile: Option) -> PyResult { + let _guard = setup_tracing_subscriber(level, logfile)?; + let logging = py.import("logging")?; + let root = logging.getattr("root")?; + root.setattr("level", level)?; + // TODO(Investigate why the file appender just create the file and does not write anything, event after holding the guard) + Ok(Self { _guard }) + } + + fn handler(&self, py: Python) -> PyResult> { let logging = py.import("logging")?; logging.setattr( "py_tracing_event", @@ -122,7 +113,6 @@ impl PyTracingHandler { let pycode = r#" class TracingHandler(Handler): - __name__ = "SmithyRsTracingHandler" """ Python logging to Rust tracing handler. """ def emit(self, record): py_tracing_event( @@ -147,14 +137,19 @@ pub fn py_tracing_event( message: &str, module: &str, filename: &str, - line: usize, + lineno: usize, pid: usize, ) -> PyResult<()> { - let span = Span::current(); - span.record("pid", pid); - span.record("module", module); - span.record("filename", filename); - span.record("lineno", line); + let span = span!( + Level::TRACE, + "python", + pid = pid, + module = module, + filename = filename, + lineno = lineno + ); + println!("message2: {message}"); + let _guard = span.enter(); match level { 40 => tracing::error!("{message}"), 30 => tracing::warn!("{message}"), @@ -182,14 +177,22 @@ pub fn py_tracing_event( #[cfg(test)] mod tests { + use pyo3::types::PyDict; + use super::*; #[test] fn tracing_handler_is_injected_in_python() { crate::tests::initialize(); Python::with_gil(|py| { - setup_tracing(py, None).unwrap(); + let handler = PyTracingHandler::newpy(py, Some(10), None).unwrap(); + let kwargs = PyDict::new(py); + kwargs + .set_item("handlers", vec![handler.handler(py).unwrap()]) + .unwrap(); let logging = py.import("logging").unwrap(); + let basic_config = logging.getattr("basicConfig").unwrap(); + basic_config.call((), Some(kwargs)).unwrap(); logging.call_method1("info", ("a message",)).unwrap(); }); } diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs index b371fbad8b..73508541a2 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/layer.rs @@ -109,7 +109,10 @@ where } fn call(&mut self, req: Request) -> Self::Future { - let inner = self.inner.clone(); + // TODO(Should we make this clone less expensive by wrapping inner in a Arc?) + let clone = self.inner.clone(); + // See https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services + let inner = std::mem::replace(&mut self.inner, clone); let run = self.handlers.run(req, self.protocol, self.locals.clone()); ResponseFuture { diff --git a/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs b/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs index 22dda22fd1..467d7dbb7b 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/middleware/request.rs @@ -47,6 +47,10 @@ pub struct PyRequest { method: String, #[pyo3(get, set)] uri: String, + // TODO(investigate if using a PyDict can make the experience more idiomatic) + // I'd like to be able to do request.headers.get("my-header") and + // request.headers["my-header"] = 42 instead of implementing set_header() and get_header() + // under pymethods. The same applies to response. pub(crate) headers: HashMap, version: Version, } diff --git a/rust-runtime/aws-smithy-http-server-python/src/server.rs b/rust-runtime/aws-smithy-http-server-python/src/server.rs index bf86bf4c48..f7d48b702b 100644 --- a/rust-runtime/aws-smithy-http-server-python/src/server.rs +++ b/rust-runtime/aws-smithy-http-server-python/src/server.rs @@ -77,15 +77,17 @@ pub trait PyApp: Clone + pyo3::IntoPy { .getattr(py, "pid") .map(|pid| pid.extract(py).unwrap_or(-1)) .unwrap_or(-1); - println!("Terminating worker {idx}, PID: {pid}"); + tracing::debug!("Terminating worker {idx}, PID: {pid}"); match worker.call_method0(py, "terminate") { Ok(_) => {} Err(e) => { - eprintln!("Error terminating worker {idx}, PID: {pid}: {e}"); + tracing::error!("Error terminating worker {idx}, PID: {pid}: {e}"); worker .call_method0(py, "kill") .map_err(|e| { - eprintln!("Unable to kill kill worker {idx}, PID: {pid}: {e}"); + tracing::error!( + "Unable to kill kill worker {idx}, PID: {pid}: {e}" + ); }) .unwrap(); } @@ -106,11 +108,11 @@ pub trait PyApp: Clone + pyo3::IntoPy { .getattr(py, "pid") .map(|pid| pid.extract(py).unwrap_or(-1)) .unwrap_or(-1); - println!("Killing worker {idx}, PID: {pid}"); + tracing::debug!("Killing worker {idx}, PID: {pid}"); worker .call_method0(py, "kill") .map_err(|e| { - eprintln!("Unable to kill kill worker {idx}, PID: {pid}: {e}"); + tracing::error!("Unable to kill kill worker {idx}, PID: {pid}: {e}"); }) .unwrap(); }); @@ -133,20 +135,20 @@ pub trait PyApp: Clone + pyo3::IntoPy { for sig in signals.forever() { match sig { SIGINT => { - println!( + tracing::info!( "Termination signal {sig:?} received, all workers will be immediately terminated" ); self.immediate_termination(self.workers()); } SIGTERM | SIGQUIT => { - println!( + tracing::info!( "Termination signal {sig:?} received, all workers will be gracefully terminated" ); self.graceful_termination(self.workers()); } _ => { - println!("Signal {sig:?} is ignored by this application"); + tracing::debug!("Signal {sig:?} is ignored by this application"); } } } @@ -287,9 +289,11 @@ event_loop.add_signal_handler(signal.SIGINT, is_coroutine, _type, }; - println!( + tracing::info!( "Registering middleware function `{}`, coroutine: {}, type: {:?}", - handler.name, handler.is_coroutine, handler._type + handler.name, + handler.is_coroutine, + handler._type ); self.middlewares().push(handler); Ok(()) @@ -314,9 +318,10 @@ event_loop.add_signal_handler(signal.SIGINT, is_coroutine, args: func_args.len(), }; - println!( + tracing::info!( "Registering handler function `{name}`, coroutine: {}, arguments: {}", - handler.is_coroutine, handler.args, + handler.is_coroutine, + handler.args, ); // Insert the handler in the handlers map. self.handlers().insert(name.to_string(), handler); @@ -445,7 +450,7 @@ event_loop.add_signal_handler(signal.SIGINT, } // Unlock the workers mutex. drop(active_workers); - println!("Rust Python server started successfully"); + tracing::info!("Rust Python server started successfully"); self.block_on_rust_signals(); Ok(()) } From 433c162d8bfc6017abe34a1c2621f7bf61a27eca Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Thu, 22 Sep 2022 14:57:06 +0100 Subject: [PATCH 29/30] Remove useless dependency --- .../codegen/server/python/smithy/PythonServerCargoDependency.kt | 1 - 1 file changed, 1 deletion(-) diff --git a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCargoDependency.kt b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCargoDependency.kt index 2c37ceb14e..d92482f324 100644 --- a/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCargoDependency.kt +++ b/codegen-server/python/src/main/kotlin/software/amazon/smithy/rust/codegen/server/python/smithy/PythonServerCargoDependency.kt @@ -19,7 +19,6 @@ object PythonServerCargoDependency { val PyO3Asyncio: CargoDependency = CargoDependency("pyo3-asyncio", CratesIo("0.16"), features = setOf("attributes", "tokio-runtime")) val Tokio: CargoDependency = CargoDependency("tokio", CratesIo("1.20.1"), features = setOf("full")) val Tracing: CargoDependency = CargoDependency("tracing", CratesIo("0.1")) - val TracingAppender: CargoDependency = CargoDependency("tracing-appender", CratesIo("0.2")) val Tower: CargoDependency = CargoDependency("tower", CratesIo("0.4")) val TowerHttp: CargoDependency = CargoDependency("tower-http", CratesIo("0.3"), features = setOf("trace")) val Hyper: CargoDependency = CargoDependency("hyper", CratesIo("0.14.12"), features = setOf("server", "http1", "http2", "tcp", "stream")) From da065a4f49b3998fb0affcc38e28355e6afc6c54 Mon Sep 17 00:00:00 2001 From: Bigo <1781140+crisidev@users.noreply.github.com> Date: Thu, 22 Sep 2022 17:55:00 +0100 Subject: [PATCH 30/30] Fix documentation --- .../aws-smithy-http-server-python/examples/pokemon_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py index 12d81f875c..67345fb629 100644 --- a/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py +++ b/rust-runtime/aws-smithy-http-server-python/examples/pokemon_service.py @@ -159,7 +159,7 @@ def add_x_amzn_answer_header(request: Request): # This middleware checks if the header `x-amzn-answer` is correctly set -# to 42. +# to 42, otherwise it returns an exception with a set status code. @app.request_middleware async def check_x_amzn_answer_header(request: Request): # Check that `x-amzn-answer` is 42.