diff --git a/integration_tests/base_routes.py b/integration_tests/base_routes.py index 697f3e18a..e704b6a30 100644 --- a/integration_tests/base_routes.py +++ b/integration_tests/base_routes.py @@ -74,20 +74,27 @@ def return_int_status_code(): return {"status_code": 202, "body": "hello", "type": "text"} -@app.before_request("/") -async def hello_before_request(request): +@app.before_request("/post_with_body") +async def hello_before_request(request,response): global callCount callCount += 1 + print("this is before request") print(request) + print(response) + response["headers"]["test_value"] = "World" return "" -@app.after_request("/") -async def hello_after_request(request): +@app.after_request("/post_with_body") +async def hello_after_request(request, response): global callCount callCount += 1 + print("this is after request") print(request) - return "" + print(response) + response["body"] = "body modified" + return "hello word" + @app.get("/test/:id") @@ -122,6 +129,9 @@ async def post(): @app.post("/post_with_body") async def postreq_with_body(request): + print("This is the main function") + print(request) + return bytearray(request["body"]).decode("utf-8") @@ -204,4 +214,4 @@ async def redirect_route(request): index_file="index.html", ) app.startup_handler(startup_handler) - app.start(port=5000) + app.start(port=5001) diff --git a/integration_tests/test_middlware.py b/integration_tests/test_middlware.py new file mode 100644 index 000000000..73a0bb747 --- /dev/null +++ b/integration_tests/test_middlware.py @@ -0,0 +1,22 @@ +import requests + + + +BASE_URL = "http://127.0.0.1:5000" + + + +def test_post_with_middleware(session): + + res = requests.post(f"{BASE_URL}/post_with_body", data = { + "hello": "world" + }) + + + + assert res.text=="hello=world" + assert (res.status_code == 200) + + + + diff --git a/robyn/router.py b/robyn/router.py index 8d003551d..b883888e0 100644 --- a/robyn/router.py +++ b/robyn/router.py @@ -81,8 +81,8 @@ def __init__(self) -> None: super().__init__() self.routes = [] - def add_route(self, route_type: str, endpoint: str, handler: Callable) -> Callable: - number_of_params = len(signature(handler).parameters) + + def add_route(self, route_type: str, endpoint: str, handler: Callable, number_of_params=0) -> None: self.routes.append( ( route_type, @@ -100,8 +100,9 @@ def add_route(self, route_type: str, endpoint: str, handler: Callable) -> Callab # and returns the arguments. # Arguments are returned as they could be modified by the middlewares. def add_after_request(self, endpoint: str) -> Callable[..., None]: + def inner(handler): - @wraps(handler) + async def async_inner_handler(*args): await handler(*args) return args @@ -111,10 +112,13 @@ def inner_handler(*args): handler(*args) return args + number_of_params = len(signature(handler).parameters) + + if iscoroutinefunction(handler): - self.add_route("AFTER_REQUEST", endpoint, async_inner_handler) + self.add_route("AFTER_REQUEST", endpoint, async_inner_handler, number_of_params) else: - self.add_route("AFTER_REQUEST", endpoint, inner_handler) + self.add_route("AFTER_REQUEST", endpoint, inner_handler, number_of_params) return inner @@ -130,10 +134,12 @@ def inner_handler(*args): handler(*args) return args + number_of_params = len(signature(handler).parameters) + if iscoroutinefunction(handler): - self.add_route("BEFORE_REQUEST", endpoint, async_inner_handler) + self.add_route("BEFORE_REQUEST", endpoint, async_inner_handler, number_of_params) else: - self.add_route("BEFORE_REQUEST", endpoint, inner_handler) + self.add_route("BEFORE_REQUEST", endpoint, inner_handler, number_of_params) return inner diff --git a/src/executors/mod.rs b/src/executors/mod.rs index 498cb5201..041947c16 100644 --- a/src/executors/mod.rs +++ b/src/executors/mod.rs @@ -7,50 +7,46 @@ use std::collections::HashMap; use std::rc::Rc; use std::sync::Arc; -use actix_web::{http::Method, web, HttpRequest}; -use anyhow::{bail, Result}; +use actix_web::HttpResponse; +use anyhow::Result; use log::debug; + use pyo3_asyncio::TaskLocals; // pyO3 module use crate::types::PyFunction; -use futures_util::stream::StreamExt; + use pyo3::prelude::*; use pyo3::types::PyDict; /// @TODO make configurable -const MAX_SIZE: usize = 10_000; pub async fn execute_middleware_function<'a>( function: PyFunction, - payload: &mut web::Payload, + payload: &mut [u8], headers: &HashMap, - req: &HttpRequest, route_params: HashMap, queries: Rc>>, number_of_params: u8, + res: Option<&HttpResponse>, ) -> Result>> { // TODO: // add body in middlewares too - let mut data: Vec = Vec::new(); - - if req.method() == Method::POST - || req.method() == Method::PUT - || req.method() == Method::PATCH - || req.method() == Method::DELETE - { - let mut body = web::BytesMut::new(); - while let Some(chunk) = payload.next().await { - let chunk = chunk?; - // limit max size of in-memory payload - if (body.len() + chunk.len()) > MAX_SIZE { - bail!("Body content Overflow"); - } - body.extend_from_slice(&chunk); - } + let data = payload.to_owned(); + let temp_response = &HttpResponse::Ok().finish(); + + // make response object accessible while creating routes + let response = res.unwrap_or(temp_response); - data = body.to_vec() + debug!("response: {:?}", response); + debug!("temp_response: {:?}", temp_response); + let mut response_headers = HashMap::new(); + for (key, val) in response.headers() { + response_headers.insert(key.to_string(), val.to_str().unwrap().to_string()); } + let mut response_dict: HashMap<&str, Py> = HashMap::new(); + let response_status_code = response.status().as_u16(); + let response_body = data.clone(); // request object accessible while creating routes let mut request = HashMap::new(); @@ -61,22 +57,28 @@ pub async fn execute_middleware_function<'a>( queries_clone.insert(key, value); } - match function { + let http_response = match function { PyFunction::CoRoutine(handler) => { let output = Python::with_gil(|py| { let handler = handler.as_ref(py); + request.insert("params", route_params.into_py(py)); request.insert("queries", queries_clone.into_py(py)); // is this a bottleneck again? request.insert("headers", headers.clone().into_py(py)); - // request.insert("body", data.into_py(py)); - + request.insert("body", data.into_py(py)); + response_dict.insert("headers", response_headers.into_py(py)); + response_dict.insert("status", response_status_code.into_py(py)); + response_dict.insert("body", response_body.into_py(py)); + debug!("response_dict: {:?}", response_dict); + debug!("res: {:?}", res); // this makes the request object to be accessible across every route let coro: PyResult<&PyAny> = match number_of_params { 0 => handler.call0(), - 1 => handler.call1((request,)), + 1 => handler.call1((response_dict,)), + 2 => handler.call1((request, response_dict)), // this is done to accomodate any future params - 2_u8..=u8::MAX => handler.call1((request,)), + 3_u8..=u8::MAX => handler.call1((request, response_dict)), }; pyo3_asyncio::tokio::into_future(coro?) })?; @@ -90,7 +92,7 @@ pub async fn execute_middleware_function<'a>( let responses = output[0].clone(); Ok(responses) })?; - + debug!("res at 97 : {:?}", res); Ok(res) } @@ -105,12 +107,16 @@ pub async fn execute_middleware_function<'a>( // is this a bottleneck again? request.insert("headers", headers.clone().into_py(py)); request.insert("body", data.into_py(py)); + response_dict.insert("headers", response_headers.into_py(py)); + response_dict.insert("status", response_status_code.into_py(py)); + response_dict.insert("body", response_body.into_py(py)); let output: PyResult<&PyAny> = match number_of_params { 0 => handler.call0(), 1 => handler.call1((request,)), + 2 => handler.call1((request, response_dict)), // this is done to accomodate any future params - 2_u8..=u8::MAX => handler.call1((request,)), + 3_u8..=u8::MAX => handler.call1((request, response_dict)), }; let output: Vec>> = output?.extract()?; @@ -120,7 +126,12 @@ pub async fn execute_middleware_function<'a>( Ok(output?) } - } + }; + + // + let &mut original_headers = response.headers_mut(); + + http_response } pub async fn execute_function( @@ -184,34 +195,15 @@ pub async fn execute_function( #[inline] pub async fn execute_http_function( function: PyFunction, - payload: &mut web::Payload, + payload: &mut [u8], headers: HashMap, - req: &HttpRequest, route_params: HashMap, queries: Rc>>, number_of_params: u8, // need to change this to return a response struct // create a custom struct for this ) -> Result> { - let mut data: Vec = Vec::new(); - - if req.method() == Method::POST - || req.method() == Method::PUT - || req.method() == Method::PATCH - || req.method() == Method::DELETE - { - let mut body = web::BytesMut::new(); - while let Some(chunk) = payload.next().await { - let chunk = chunk?; - // limit max size of in-memory payload - if (body.len() + chunk.len()) > MAX_SIZE { - bail!("Body content Overflow"); - } - body.extend_from_slice(&chunk); - } - - data = body.to_vec() - } + let data: Vec = payload.to_owned(); // request object accessible while creating routes let mut request = HashMap::new(); diff --git a/src/request_handler/mod.rs b/src/request_handler/mod.rs index dc7601d03..c2c94b83c 100644 --- a/src/request_handler/mod.rs +++ b/src/request_handler/mod.rs @@ -5,7 +5,7 @@ use std::rc::Rc; use std::str::FromStr; use std::{cell::RefCell, collections::HashMap}; -use actix_web::{web, HttpRequest, HttpResponse, HttpResponseBuilder}; +use actix_web::{HttpResponse, HttpResponseBuilder}; // pyO3 module use crate::types::PyFunction; @@ -31,8 +31,7 @@ pub async fn handle_http_request( function: PyFunction, number_of_params: u8, headers: HashMap, - payload: &mut web::Payload, - req: &HttpRequest, + payload: &mut [u8], route_params: HashMap, queries: Rc>>, ) -> HttpResponse { @@ -40,7 +39,6 @@ pub async fn handle_http_request( function, payload, headers.clone(), - req, route_params, queries, number_of_params, @@ -56,56 +54,28 @@ pub async fn handle_http_request( } }; - let body = contents.get("body").unwrap().to_owned(); - let status_code = - actix_http::StatusCode::from_str(contents.get("status_code").unwrap()).unwrap(); + // removed the response creation - let response_headers: HashMap = match contents.get("headers") { - Some(headers) => { - let h: HashMap = serde_json::from_str(headers).unwrap(); - h - } - None => HashMap::new(), - }; - - debug!( - "These are the request headers from serde {:?}", - response_headers - ); - - let mut response = HttpResponse::build(status_code); - apply_headers(&mut response, response_headers); - let final_response = if !body.is_empty() { - response.body(body) - } else { - response.finish() - }; - - debug!( - "The response status code is {} and the headers are {:?}", - final_response.status(), - final_response.headers() - ); - final_response + contents } pub async fn handle_http_middleware_request( function: PyFunction, number_of_params: u8, headers: &HashMap, - payload: &mut web::Payload, - req: &HttpRequest, + payload: &mut [u8], route_params: HashMap, queries: Rc>>, + res: Option<&HttpResponse>, ) -> HashMap> { let contents = match execute_middleware_function( function, payload, headers, - req, route_params, queries, number_of_params, + res, ) .await { diff --git a/src/server.rs b/src/server.rs index f924b79ed..bf590c318 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,7 +3,6 @@ use crate::io_helpers::apply_headers; use crate::request_handler::{handle_http_middleware_request, handle_http_request}; use crate::routers::const_router::ConstRouter; - use crate::routers::router::Router; use crate::routers::types::MiddlewareRoute; use crate::routers::{middleware_router::MiddlewareRouter, web_socket_router::WebSocketRouter}; @@ -23,9 +22,10 @@ use std::thread; use actix_files::Files; use actix_http::header::HeaderMap; -use actix_http::KeepAlive; +use actix_http::{KeepAlive, Method}; use actix_web::*; use dashmap::DashMap; +use futures_util::stream::StreamExt; // pyO3 module use log::debug; @@ -298,11 +298,14 @@ impl Server { number_of_params: u8, ) { debug!("MiddleWare Route added for {} {} ", route_type, route); - - let route_type = MiddlewareRoute::from_str(route_type); - self.middleware_router - .add_route(route_type, route, handler, is_async, number_of_params) + .add_route( + MiddlewareRoute::from_str(route_type), + route, + handler, + is_async, + number_of_params, + ) .unwrap(); } @@ -369,6 +372,40 @@ async fn merge_headers( headers } +async fn create_response(contents: HashMap) -> HTTPResponse { + let body = contents.get("body").unwrap().to_owned(); + let status_code = + actix_http::StatusCode::from_str(contents.get("status_code").unwrap()).unwrap(); + + let response_headers: HashMap = match contents.get("headers") { + Some(headers) => { + let h: HashMap = serde_json::from_str(headers).unwrap(); + h + } + None => HashMap::new(), + }; + + debug!( + "These are the request headers from serde {:?}", + response_headers + ); + + let mut response = HttpResponse::build(status_code); + apply_headers(&mut response, response_headers); + let final_response = if !body.is_empty() { + response.body(body) + } else { + response.finish() + }; + + debug!( + "The response status code is {} and the headers are {:?}", + final_response.status(), + final_response.headers() + ); + final_response +} + /// This is our service handler. It receives a Request, routes on it /// path, and returns a Future of a Response. async fn index( @@ -392,19 +429,38 @@ async fn index( } let headers = merge_headers(&global_headers, req.headers()).await; + const MAX_SIZE: usize = 10_000; + let mut data: Vec = Vec::new(); + + if req.method() == Method::POST + || req.method() == Method::PUT + || req.method() == Method::PATCH + || req.method() == Method::DELETE + { + let mut body = web::BytesMut::new(); + while let Some(chunk) = payload.next().await { + let chunk = chunk.unwrap(); + // limit max size of in-memory payload + if (body.len() + chunk.len()) > MAX_SIZE { + return HttpResponse::PayloadTooLarge().finish(); + } + body.extend_from_slice(&chunk); + } - // need a better name for this - let tuple_params = + data = body.to_vec() + } + + let inital_contents = match middleware_router.get_route(MiddlewareRoute::BeforeRequest, req.uri().path()) { Some(((handler_function, number_of_params), route_params)) => { let x = handle_http_middleware_request( handler_function, number_of_params, &headers, - &mut payload, - &req, + &mut data, route_params, queries.clone(), + None, ) .await; debug!("Middleware contents {:?}", x); @@ -412,18 +468,17 @@ async fn index( } None => HashMap::new(), }; + debug!("These are the tuple params {:?}", inital_contents); - debug!("These are the tuple params {:?}", tuple_params); - - let headers_dup = if !tuple_params.is_empty() { - tuple_params.get("headers").unwrap().clone() + let headers_dup = if !inital_contents.is_empty() { + inital_contents.get("headers").unwrap().clone() } else { headers }; debug!("These are the request headers {:?}", headers_dup); - let response = if const_router + let contents = if const_router .get_route(req.method().clone(), req.uri().path()) .is_some() { @@ -441,8 +496,7 @@ async fn index( handler_function, number_of_params, headers_dup.clone(), - &mut payload, - &req, + &mut data, route_params, queries.clone(), ) @@ -456,21 +510,28 @@ async fn index( } }; - if let Some(((handler_function, number_of_params), route_params)) = + let final_contents = if let Some(((handler_function, number_of_params), route_params)) = middleware_router.get_route(MiddlewareRoute::AfterRequest, req.uri().path()) { let x = handle_http_middleware_request( handler_function, number_of_params, &headers_dup, - &mut payload, - &req, + &mut data, route_params, queries.clone(), + Some(&contents), ) .await; - debug!("{:?}", x); + debug!("this is the response from the after request {:?}", x); + x + } else { + HashMap::new() }; - response + if !final_contents.is_empty() { + create_response(final_contents) + } else { + create_response(contents) + } }